From de3317e26bb67a2a7ea015a183bbd1d369880ebd Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 17 Mar 2026 17:58:01 -0500 Subject: [PATCH] refac --- backend/open_webui/__init__.py | 64 +- backend/open_webui/config.py | 3281 ++++++++-------- backend/open_webui/constants.py | 138 +- backend/open_webui/env.py | 707 ++-- backend/open_webui/functions.py | 180 +- backend/open_webui/internal/db.py | 50 +- .../internal/migrations/001_initial_schema.py | 48 +- .../migrations/002_add_local_sharing.py | 6 +- .../migrations/003_add_auth_api_key.py | 6 +- .../internal/migrations/004_add_archived.py | 4 +- .../internal/migrations/005_add_updated_at.py | 36 +- .../006_migrate_timestamps_and_charfields.py | 40 +- .../migrations/007_add_user_last_active_at.py | 12 +- .../internal/migrations/008_add_memory.py | 4 +- .../internal/migrations/009_add_models.py | 4 +- .../010_migrate_modelfiles_to_models.py | 48 +- .../migrations/011_add_user_settings.py | 4 +- .../internal/migrations/012_add_tools.py | 4 +- .../internal/migrations/013_add_user_info.py | 4 +- .../internal/migrations/014_add_files.py | 4 +- .../internal/migrations/015_add_functions.py | 4 +- .../016_add_valves_and_is_active.py | 12 +- .../migrations/017_add_user_oauth_sub.py | 4 +- .../migrations/018_add_function_is_global.py | 4 +- backend/open_webui/internal/wrappers.py | 32 +- backend/open_webui/main.py | 1057 +++--- backend/open_webui/migrations/env.py | 24 +- backend/open_webui/migrations/util.py | 2 +- .../versions/018012973d35_add_indexes.py | 32 +- .../versions/1af9b942657b_migrate_tags.py | 85 +- .../242a2047eae0_update_chat_table.py | 58 +- ...update_message_and_channel_member_table.py | 64 +- .../374d2f66af06_add_prompt_history_table.py | 176 +- .../3781e22d8b01_update_message_table.py | 54 +- .../37f288994c47_add_group_member_table.py | 90 +- .../38d63c18f30f_add_oauth_session_table.py | 64 +- .../versions/3ab32c4b8f59_update_tags.py | 57 +- .../3af16a1c9fb6_update_user_table.py | 20 +- .../3e0e00844bb0_add_knowledge_file_table.py | 108 +- ...ce53fd72c8_update_folder_table_datetime.py | 24 +- .../57c599a3cb57_add_channel_table.py | 46 +- .../6283dc0e4d8d_add_channel_file_table.py | 34 +- .../6a39f3d8e55c_add_knowledge_table.py | 48 +- .../7826ab40b532_update_file_table.py | 10 +- .../migrations/versions/7e5b5dc7342b_init.py | 272 +- ...update_channel_file_and_knowledge_table.py | 24 +- .../8452d01d26d7_add_chat_message_table.py | 164 +- ...pdate_channel_and_channel_members_table.py | 66 +- .../versions/922e7a387820_add_group_table.py | 58 +- .../versions/9f0c9cd09105_add_note_table.py | 24 +- .../versions/a1b2c3d4e5f6_add_skill_table.py | 36 +- ...13937_add_reply_to_id_column_to_message.py | 10 +- .../af906e964978_add_feedback_table.py | 36 +- .../b10670c03dd5_update_user_table.py | 150 +- ...3d4e5f6a7_add_scim_column_to_user_table.py | 8 +- .../c0fbf31ca0db_update_file_table.py | 16 +- .../c29facfe716b_update_file_table_path.py | 34 +- .../c440947495f3_add_chat_file_table.py | 38 +- .../versions/c69f45358db4_add_folder_table.py | 36 +- .../versions/ca81bd47c050_add_config_table.py | 20 +- .../d31026856c01_update_folder_table_data.py | 8 +- .../f1e2d3c4b5a6_add_access_grant_table.py | 227 +- backend/open_webui/models/access_grants.py | 231 +- backend/open_webui/models/auths.py | 52 +- backend/open_webui/models/channels.py | 331 +- backend/open_webui/models/chat_messages.py | 287 +- backend/open_webui/models/chats.py | 673 ++-- backend/open_webui/models/feedbacks.py | 147 +- backend/open_webui/models/files.py | 118 +- backend/open_webui/models/folders.py | 72 +- backend/open_webui/models/functions.py | 149 +- backend/open_webui/models/groups.py | 189 +- backend/open_webui/models/knowledge.py | 259 +- backend/open_webui/models/memories.py | 28 +- backend/open_webui/models/messages.py | 231 +- backend/open_webui/models/models.py | 174 +- backend/open_webui/models/notes.py | 152 +- backend/open_webui/models/oauth_sessions.py | 86 +- backend/open_webui/models/prompt_history.py | 48 +- backend/open_webui/models/prompts.py | 194 +- backend/open_webui/models/skills.py | 108 +- backend/open_webui/models/tags.py | 82 +- backend/open_webui/models/tools.py | 118 +- backend/open_webui/models/users.py | 227 +- .../retrieval/loaders/datalab_marker.py | 201 +- .../retrieval/loaders/external_document.py | 33 +- .../retrieval/loaders/external_web.py | 12 +- backend/open_webui/retrieval/loaders/main.py | 411 +- .../open_webui/retrieval/loaders/mineru.py | 252 +- .../open_webui/retrieval/loaders/mistral.py | 330 +- .../open_webui/retrieval/loaders/tavily.py | 32 +- .../open_webui/retrieval/loaders/youtube.py | 60 +- .../open_webui/retrieval/models/colbert.py | 28 +- .../open_webui/retrieval/models/external.py | 36 +- backend/open_webui/retrieval/utils.py | 636 ++-- .../open_webui/retrieval/vector/dbs/chroma.py | 70 +- .../retrieval/vector/dbs/elasticsearch.py | 177 +- .../retrieval/vector/dbs/mariadb_vector.py | 132 +- .../open_webui/retrieval/vector/dbs/milvus.py | 217 +- .../vector/dbs/milvus_multitenancy.py | 128 +- .../retrieval/vector/dbs/opengauss.py | 161 +- .../retrieval/vector/dbs/opensearch.py | 164 +- .../retrieval/vector/dbs/oracle23ai.py | 232 +- .../retrieval/vector/dbs/pgvector.py | 283 +- .../retrieval/vector/dbs/pinecone.py | 252 +- .../open_webui/retrieval/vector/dbs/qdrant.py | 84 +- .../vector/dbs/qdrant_multitenancy.py | 80 +- .../retrieval/vector/dbs/s3vector.py | 349 +- .../retrieval/vector/dbs/weaviate.py | 124 +- .../open_webui/retrieval/vector/factory.py | 3 +- backend/open_webui/retrieval/vector/main.py | 4 +- backend/open_webui/retrieval/vector/type.py | 24 +- backend/open_webui/retrieval/vector/utils.py | 6 +- backend/open_webui/retrieval/web/azure.py | 55 +- backend/open_webui/retrieval/web/bing.py | 36 +- backend/open_webui/retrieval/web/bocha.py | 45 +- backend/open_webui/retrieval/web/brave.py | 24 +- .../open_webui/retrieval/web/duckduckgo.py | 14 +- backend/open_webui/retrieval/web/exa.py | 32 +- backend/open_webui/retrieval/web/external.py | 20 +- backend/open_webui/retrieval/web/firecrawl.py | 8 +- .../open_webui/retrieval/web/google_pse.py | 30 +- .../open_webui/retrieval/web/jina_search.py | 24 +- backend/open_webui/retrieval/web/kagi.py | 18 +- backend/open_webui/retrieval/web/main.py | 2 +- backend/open_webui/retrieval/web/mojeek.py | 17 +- backend/open_webui/retrieval/web/ollama.py | 20 +- .../open_webui/retrieval/web/perplexity.py | 61 +- .../retrieval/web/perplexity_search.py | 25 +- backend/open_webui/retrieval/web/searchapi.py | 22 +- backend/open_webui/retrieval/web/searxng.py | 50 +- backend/open_webui/retrieval/web/serpapi.py | 22 +- backend/open_webui/retrieval/web/serper.py | 22 +- backend/open_webui/retrieval/web/serply.py | 44 +- backend/open_webui/retrieval/web/serpstack.py | 18 +- backend/open_webui/retrieval/web/sougou.py | 25 +- backend/open_webui/retrieval/web/tavily.py | 16 +- backend/open_webui/retrieval/web/utils.py | 217 +- backend/open_webui/retrieval/web/yacy.py | 36 +- backend/open_webui/retrieval/web/yandex.py | 108 +- backend/open_webui/retrieval/web/ydc.py | 26 +- backend/open_webui/routers/analytics.py | 132 +- backend/open_webui/routers/audio.py | 824 ++--- backend/open_webui/routers/auths.py | 666 ++-- backend/open_webui/routers/channels.py | 1150 +++--- backend/open_webui/routers/chats.py | 582 ++- backend/open_webui/routers/configs.py | 335 +- backend/open_webui/routers/evaluations.py | 191 +- backend/open_webui/routers/files.py | 342 +- backend/open_webui/routers/folders.py | 120 +- backend/open_webui/routers/functions.py | 197 +- backend/open_webui/routers/groups.py | 69 +- backend/open_webui/routers/images.py | 756 ++-- backend/open_webui/routers/knowledge.py | 311 +- backend/open_webui/routers/memories.py | 103 +- backend/open_webui/routers/models.py | 192 +- backend/open_webui/routers/notes.py | 152 +- backend/open_webui/routers/ollama.py | 841 ++--- backend/open_webui/routers/openai.py | 756 ++-- backend/open_webui/routers/pipelines.py | 239 +- backend/open_webui/routers/prompts.py | 200 +- backend/open_webui/routers/retrieval.py | 1481 ++++---- backend/open_webui/routers/scim.py | 382 +- backend/open_webui/routers/skills.py | 116 +- backend/open_webui/routers/tasks.py | 501 ++- backend/open_webui/routers/terminals.py | 145 +- backend/open_webui/routers/tools.py | 312 +- backend/open_webui/routers/users.py | 192 +- backend/open_webui/routers/utils.py | 50 +- backend/open_webui/socket/main.py | 541 ++- backend/open_webui/socket/utils.py | 53 +- backend/open_webui/storage/provider.py | 143 +- backend/open_webui/tasks.py | 56 +- .../test/apps/webui/routers/test_auths.py | 190 +- .../test/apps/webui/routers/test_models.py | 52 +- .../test/apps/webui/routers/test_users.py | 134 +- .../test/apps/webui/storage/test_provider.py | 163 +- backend/open_webui/test/util/test_redis.py | 444 ++- backend/open_webui/tools/builtin.py | 897 +++-- .../utils/access_control/__init__.py | 88 +- .../open_webui/utils/access_control/files.py | 24 +- backend/open_webui/utils/actions.py | 70 +- backend/open_webui/utils/anthropic.py | 416 +-- backend/open_webui/utils/audit.py | 96 +- backend/open_webui/utils/auth.py | 178 +- backend/open_webui/utils/channels.py | 10 +- backend/open_webui/utils/chat.py | 172 +- backend/open_webui/utils/code_interpreter.py | 141 +- backend/open_webui/utils/embeddings.py | 24 +- backend/open_webui/utils/files.py | 54 +- backend/open_webui/utils/filter.py | 69 +- backend/open_webui/utils/groups.py | 4 +- backend/open_webui/utils/headers.py | 2 +- backend/open_webui/utils/images/comfyui.py | 234 +- backend/open_webui/utils/logger.py | 100 +- backend/open_webui/utils/mcp/client.py | 48 +- backend/open_webui/utils/middleware.py | 3291 ++++++++--------- backend/open_webui/utils/misc.py | 411 +- backend/open_webui/utils/models.py | 311 +- backend/open_webui/utils/oauth.py | 766 ++-- backend/open_webui/utils/payload.py | 246 +- backend/open_webui/utils/pdf_generator.py | 48 +- backend/open_webui/utils/plugin.py | 196 +- backend/open_webui/utils/rate_limit.py | 10 +- backend/open_webui/utils/redis.py | 85 +- backend/open_webui/utils/response.py | 174 +- backend/open_webui/utils/sanitize.py | 10 +- backend/open_webui/utils/security_headers.py | 72 +- backend/open_webui/utils/task.py | 219 +- .../open_webui/utils/telemetry/constants.py | 32 +- .../utils/telemetry/instrumentors.py | 48 +- backend/open_webui/utils/telemetry/logs.py | 6 +- backend/open_webui/utils/telemetry/metrics.py | 70 +- backend/open_webui/utils/telemetry/setup.py | 6 +- backend/open_webui/utils/tools.py | 691 ++-- backend/open_webui/utils/validate.py | 16 +- backend/open_webui/utils/webhook.py | 48 +- contribution_stats.py | 26 +- hatch_build.py | 18 +- package.json | 2 +- 220 files changed, 17200 insertions(+), 22836 deletions(-) diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py index 967a49de8f..acb70e17e2 100644 --- a/backend/open_webui/__init__.py +++ b/backend/open_webui/__init__.py @@ -10,94 +10,88 @@ from typing_extensions import Annotated app = typer.Typer() -KEY_FILE = Path.cwd() / ".webui_secret_key" +KEY_FILE = Path.cwd() / '.webui_secret_key' def version_callback(value: bool): if value: from open_webui.env import VERSION - typer.echo(f"Open WebUI version: {VERSION}") + typer.echo(f'Open WebUI version: {VERSION}') raise typer.Exit() @app.command() def main( - version: Annotated[ - Optional[bool], typer.Option("--version", callback=version_callback) - ] = None, + version: Annotated[Optional[bool], typer.Option('--version', callback=version_callback)] = None, ): pass @app.command() def serve( - host: str = "0.0.0.0", + host: str = '0.0.0.0', port: int = 8080, ): - os.environ["FROM_INIT_PY"] = "true" - if os.getenv("WEBUI_SECRET_KEY") is None: - typer.echo( - "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable." - ) + os.environ['FROM_INIT_PY'] = 'true' + if os.getenv('WEBUI_SECRET_KEY') is None: + typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.') if not KEY_FILE.exists(): - typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}") + typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}') KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12))) - typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}") - os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text() + typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}') + os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text() - if os.getenv("USE_CUDA_DOCKER", "false") == "true": - typer.echo( - "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries." - ) - LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":") - os.environ["LD_LIBRARY_PATH"] = ":".join( + if os.getenv('USE_CUDA_DOCKER', 'false') == 'true': + typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.') + LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH', '').split(':') + os.environ['LD_LIBRARY_PATH'] = ':'.join( LD_LIBRARY_PATH + [ - "/usr/local/lib/python3.11/site-packages/torch/lib", - "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib", + '/usr/local/lib/python3.11/site-packages/torch/lib', + '/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib', ] ) try: import torch - assert torch.cuda.is_available(), "CUDA not available" - typer.echo("CUDA seems to be working") + assert torch.cuda.is_available(), 'CUDA not available' + typer.echo('CUDA seems to be working') except Exception as e: typer.echo( - "Error when testing CUDA but USE_CUDA_DOCKER is true. " - "Resetting USE_CUDA_DOCKER to false and removing " - f"LD_LIBRARY_PATH modifications: {e}" + 'Error when testing CUDA but USE_CUDA_DOCKER is true. ' + 'Resetting USE_CUDA_DOCKER to false and removing ' + f'LD_LIBRARY_PATH modifications: {e}' ) - os.environ["USE_CUDA_DOCKER"] = "false" - os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH) + os.environ['USE_CUDA_DOCKER'] = 'false' + os.environ['LD_LIBRARY_PATH'] = ':'.join(LD_LIBRARY_PATH) import open_webui.main # we need set environment variables before importing main from open_webui.env import UVICORN_WORKERS # Import the workers setting uvicorn.run( - "open_webui.main:app", + 'open_webui.main:app', host=host, port=port, - forwarded_allow_ips="*", + forwarded_allow_ips='*', workers=UVICORN_WORKERS, ) @app.command() def dev( - host: str = "0.0.0.0", + host: str = '0.0.0.0', port: int = 8080, reload: bool = True, ): uvicorn.run( - "open_webui.main:app", + 'open_webui.main:app', host=host, port=port, reload=reload, - forwarded_allow_ips="*", + forwarded_allow_ips='*', ) -if __name__ == "__main__": +if __name__ == '__main__': app() diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 13de48555a..3cbeb36644 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -41,11 +41,11 @@ from open_webui.utils.redis import get_redis_connection class EndpointFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: - return record.getMessage().find("/health") == -1 + return record.getMessage().find('/health') == -1 # Filter out /endpoint -logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) +logging.getLogger('uvicorn.access').addFilter(EndpointFilter()) #################################### # Config helpers @@ -54,20 +54,20 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) # Function to run the alembic migrations def run_migrations(): - log.info("Running migrations") + log.info('Running migrations') try: from alembic import command from alembic.config import Config - alembic_cfg = Config(OPEN_WEBUI_DIR / "alembic.ini") + alembic_cfg = Config(OPEN_WEBUI_DIR / 'alembic.ini') # Set the script location dynamically - migrations_path = OPEN_WEBUI_DIR / "migrations" - alembic_cfg.set_main_option("script_location", str(migrations_path)) + migrations_path = OPEN_WEBUI_DIR / 'migrations' + alembic_cfg.set_main_option('script_location', str(migrations_path)) - command.upgrade(alembic_cfg, "head") + command.upgrade(alembic_cfg, 'head') except Exception as e: - log.exception(f"Error running migrations: {e}") + log.exception(f'Error running migrations: {e}') if ENABLE_DB_MIGRATIONS: @@ -75,7 +75,7 @@ if ENABLE_DB_MIGRATIONS: class Config(Base): - __tablename__ = "config" + __tablename__ = 'config' id = Column(Integer, primary_key=True) data = Column(JSON, nullable=False) @@ -85,7 +85,7 @@ class Config(Base): def load_json_config(): - with open(f"{DATA_DIR}/config.json", "r") as file: + with open(f'{DATA_DIR}/config.json', 'r') as file: return json.load(file) @@ -109,14 +109,14 @@ def reset_config(): # When initializing, check if config.json exists and migrate it to the database -if os.path.exists(f"{DATA_DIR}/config.json"): +if os.path.exists(f'{DATA_DIR}/config.json'): data = load_json_config() save_to_db(data) - os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json") + os.rename(f'{DATA_DIR}/config.json', f'{DATA_DIR}/old_config.json') DEFAULT_CONFIG = { - "version": 0, - "ui": {}, + 'version': 0, + 'ui': {}, } @@ -130,7 +130,7 @@ CONFIG_DATA = get_config() def get_config_value(config_path: str): - path_parts = config_path.split(".") + path_parts = config_path.split('.') cur_config = CONFIG_DATA for key in path_parts: if key in cur_config: @@ -159,11 +159,9 @@ def save_config(config): return True -T = TypeVar("T") +T = TypeVar('T') -ENABLE_PERSISTENT_CONFIG = ( - os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true" -) +ENABLE_PERSISTENT_CONFIG = os.environ.get('ENABLE_PERSISTENT_CONFIG', 'True').lower() == 'true' class PersistentConfig(Generic[T]): @@ -174,13 +172,8 @@ class PersistentConfig(Generic[T]): self.config_value = get_config_value(config_path) if self.config_value is not None and ENABLE_PERSISTENT_CONFIG: - if ( - self.config_path.startswith("oauth.") - and not ENABLE_OAUTH_PERSISTENT_CONFIG - ): - log.info( - f"Skipping loading of '{env_name}' as OAuth persistent config is disabled" - ) + if self.config_path.startswith('oauth.') and not ENABLE_OAUTH_PERSISTENT_CONFIG: + log.info(f"Skipping loading of '{env_name}' as OAuth persistent config is disabled") self.value = env_value else: log.info(f"'{env_name}' loaded from the latest database entry") @@ -195,26 +188,22 @@ class PersistentConfig(Generic[T]): @property def __dict__(self): - raise TypeError( - "PersistentConfig object cannot be converted to dict, use config_get or .value instead." - ) + raise TypeError('PersistentConfig object cannot be converted to dict, use config_get or .value instead.') def __getattribute__(self, item): - if item == "__dict__": - raise TypeError( - "PersistentConfig object cannot be converted to dict, use config_get or .value instead." - ) + if item == '__dict__': + raise TypeError('PersistentConfig object cannot be converted to dict, use config_get or .value instead.') return super().__getattribute__(item) def update(self): new_value = get_config_value(self.config_path) if new_value is not None: self.value = new_value - log.info(f"Updated {self.env_name} to new value {self.value}") + log.info(f'Updated {self.env_name} to new value {self.value}') def save(self): log.info(f"Saving '{self.env_name}' to the database") - path_parts = self.config_path.split(".") + path_parts = self.config_path.split('.') sub_config = CONFIG_DATA for key in path_parts[:-1]: if key not in sub_config: @@ -236,12 +225,12 @@ class AppConfig: redis_url: Optional[str] = None, redis_sentinels: Optional[list] = [], redis_cluster: Optional[bool] = False, - redis_key_prefix: str = "open-webui", + redis_key_prefix: str = 'open-webui', ): if redis_url: - super().__setattr__("_redis_key_prefix", redis_key_prefix) + super().__setattr__('_redis_key_prefix', redis_key_prefix) super().__setattr__( - "_redis", + '_redis', get_redis_connection( redis_url, redis_sentinels, @@ -250,7 +239,7 @@ class AppConfig: ), ) - super().__setattr__("_state", {}) + super().__setattr__('_state', {}) def __setattr__(self, key, value): if isinstance(value, PersistentConfig): @@ -260,7 +249,7 @@ class AppConfig: self._state[key].save() if self._redis and ENABLE_PERSISTENT_CONFIG: - redis_key = f"{self._redis_key_prefix}:config:{key}" + redis_key = f'{self._redis_key_prefix}:config:{key}' self._redis.set(redis_key, json.dumps(self._state[key].value)) def __getattr__(self, key): @@ -269,7 +258,7 @@ class AppConfig: # If Redis is available and persistent config is enabled, check for an updated value if self._redis and ENABLE_PERSISTENT_CONFIG: - redis_key = f"{self._redis_key_prefix}:config:{key}" + redis_key = f'{self._redis_key_prefix}:config:{key}' redis_value = self._redis.get(redis_key) if redis_value is not None: @@ -279,10 +268,10 @@ class AppConfig: # Update the in-memory value if different if self._state[key].value != decoded_value: self._state[key].value = decoded_value - log.info(f"Updated {key} from Redis: {decoded_value}") + log.info(f'Updated {key} from Redis: {decoded_value}') except json.JSONDecodeError: - log.error(f"Invalid JSON format in Redis for {key}: {redis_value}") + log.error(f'Invalid JSON format in Redis for {key}: {redis_value}') return self._state[key].value @@ -292,389 +281,365 @@ class AppConfig: #################################### ENABLE_API_KEYS = PersistentConfig( - "ENABLE_API_KEYS", - "auth.enable_api_keys", - os.environ.get("ENABLE_API_KEYS", "False").lower() == "true", + 'ENABLE_API_KEYS', + 'auth.enable_api_keys', + os.environ.get('ENABLE_API_KEYS', 'False').lower() == 'true', ) ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = PersistentConfig( - "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS", - "auth.api_key.endpoint_restrictions", + 'ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS', + 'auth.api_key.endpoint_restrictions', os.environ.get( - "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS", - os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False"), + 'ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS', + os.environ.get('ENABLE_API_KEY_ENDPOINT_RESTRICTIONS', 'False'), ).lower() - == "true", + == 'true', ) API_KEYS_ALLOWED_ENDPOINTS = PersistentConfig( - "API_KEYS_ALLOWED_ENDPOINTS", - "auth.api_key.allowed_endpoints", - os.environ.get( - "API_KEYS_ALLOWED_ENDPOINTS", os.environ.get("API_KEY_ALLOWED_ENDPOINTS", "") - ), + 'API_KEYS_ALLOWED_ENDPOINTS', + 'auth.api_key.allowed_endpoints', + os.environ.get('API_KEYS_ALLOWED_ENDPOINTS', os.environ.get('API_KEY_ALLOWED_ENDPOINTS', '')), ) -JWT_EXPIRES_IN = PersistentConfig( - "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "4w") -) +JWT_EXPIRES_IN = PersistentConfig('JWT_EXPIRES_IN', 'auth.jwt_expiry', os.environ.get('JWT_EXPIRES_IN', '4w')) -if JWT_EXPIRES_IN.value == "-1": +if JWT_EXPIRES_IN.value == '-1': log.warning( "⚠️ SECURITY WARNING: JWT_EXPIRES_IN is set to '-1'\n" - " See: https://docs.openwebui.com/reference/env-configuration\n" + ' See: https://docs.openwebui.com/reference/env-configuration\n' ) #################################### # OAuth config #################################### -ENABLE_OAUTH_PERSISTENT_CONFIG = ( - os.environ.get("ENABLE_OAUTH_PERSISTENT_CONFIG", "False").lower() == "true" -) +ENABLE_OAUTH_PERSISTENT_CONFIG = os.environ.get('ENABLE_OAUTH_PERSISTENT_CONFIG', 'False').lower() == 'true' ENABLE_OAUTH_SIGNUP = PersistentConfig( - "ENABLE_OAUTH_SIGNUP", - "oauth.enable_signup", - os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", + 'ENABLE_OAUTH_SIGNUP', + 'oauth.enable_signup', + os.environ.get('ENABLE_OAUTH_SIGNUP', 'False').lower() == 'true', ) OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE = PersistentConfig( - "OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE", - "oauth.refresh_token_include_scope", - os.environ.get("OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE", "False").lower() == "true", + 'OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE', + 'oauth.refresh_token_include_scope', + os.environ.get('OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE', 'False').lower() == 'true', ) OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig( - "OAUTH_MERGE_ACCOUNTS_BY_EMAIL", - "oauth.merge_accounts_by_email", - os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true", + 'OAUTH_MERGE_ACCOUNTS_BY_EMAIL', + 'oauth.merge_accounts_by_email', + os.environ.get('OAUTH_MERGE_ACCOUNTS_BY_EMAIL', 'False').lower() == 'true', ) OAUTH_PROVIDERS = {} GOOGLE_CLIENT_ID = PersistentConfig( - "GOOGLE_CLIENT_ID", - "oauth.google.client_id", - os.environ.get("GOOGLE_CLIENT_ID", ""), + 'GOOGLE_CLIENT_ID', + 'oauth.google.client_id', + os.environ.get('GOOGLE_CLIENT_ID', ''), ) GOOGLE_CLIENT_SECRET = PersistentConfig( - "GOOGLE_CLIENT_SECRET", - "oauth.google.client_secret", - os.environ.get("GOOGLE_CLIENT_SECRET", ""), + 'GOOGLE_CLIENT_SECRET', + 'oauth.google.client_secret', + os.environ.get('GOOGLE_CLIENT_SECRET', ''), ) GOOGLE_OAUTH_SCOPE = PersistentConfig( - "GOOGLE_OAUTH_SCOPE", - "oauth.google.scope", - os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"), + 'GOOGLE_OAUTH_SCOPE', + 'oauth.google.scope', + os.environ.get('GOOGLE_OAUTH_SCOPE', 'openid email profile'), ) GOOGLE_REDIRECT_URI = PersistentConfig( - "GOOGLE_REDIRECT_URI", - "oauth.google.redirect_uri", - os.environ.get("GOOGLE_REDIRECT_URI", ""), + 'GOOGLE_REDIRECT_URI', + 'oauth.google.redirect_uri', + os.environ.get('GOOGLE_REDIRECT_URI', ''), ) MICROSOFT_CLIENT_ID = PersistentConfig( - "MICROSOFT_CLIENT_ID", - "oauth.microsoft.client_id", - os.environ.get("MICROSOFT_CLIENT_ID", ""), + 'MICROSOFT_CLIENT_ID', + 'oauth.microsoft.client_id', + os.environ.get('MICROSOFT_CLIENT_ID', ''), ) MICROSOFT_CLIENT_SECRET = PersistentConfig( - "MICROSOFT_CLIENT_SECRET", - "oauth.microsoft.client_secret", - os.environ.get("MICROSOFT_CLIENT_SECRET", ""), + 'MICROSOFT_CLIENT_SECRET', + 'oauth.microsoft.client_secret', + os.environ.get('MICROSOFT_CLIENT_SECRET', ''), ) MICROSOFT_CLIENT_TENANT_ID = PersistentConfig( - "MICROSOFT_CLIENT_TENANT_ID", - "oauth.microsoft.tenant_id", - os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""), + 'MICROSOFT_CLIENT_TENANT_ID', + 'oauth.microsoft.tenant_id', + os.environ.get('MICROSOFT_CLIENT_TENANT_ID', ''), ) MICROSOFT_CLIENT_LOGIN_BASE_URL = PersistentConfig( - "MICROSOFT_CLIENT_LOGIN_BASE_URL", - "oauth.microsoft.login_base_url", - os.environ.get( - "MICROSOFT_CLIENT_LOGIN_BASE_URL", "https://login.microsoftonline.com" - ), + 'MICROSOFT_CLIENT_LOGIN_BASE_URL', + 'oauth.microsoft.login_base_url', + os.environ.get('MICROSOFT_CLIENT_LOGIN_BASE_URL', 'https://login.microsoftonline.com'), ) MICROSOFT_CLIENT_PICTURE_URL = PersistentConfig( - "MICROSOFT_CLIENT_PICTURE_URL", - "oauth.microsoft.picture_url", + 'MICROSOFT_CLIENT_PICTURE_URL', + 'oauth.microsoft.picture_url', os.environ.get( - "MICROSOFT_CLIENT_PICTURE_URL", - "https://graph.microsoft.com/v1.0/me/photo/$value", + 'MICROSOFT_CLIENT_PICTURE_URL', + 'https://graph.microsoft.com/v1.0/me/photo/$value', ), ) MICROSOFT_OAUTH_SCOPE = PersistentConfig( - "MICROSOFT_OAUTH_SCOPE", - "oauth.microsoft.scope", - os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"), + 'MICROSOFT_OAUTH_SCOPE', + 'oauth.microsoft.scope', + os.environ.get('MICROSOFT_OAUTH_SCOPE', 'openid email profile'), ) MICROSOFT_REDIRECT_URI = PersistentConfig( - "MICROSOFT_REDIRECT_URI", - "oauth.microsoft.redirect_uri", - os.environ.get("MICROSOFT_REDIRECT_URI", ""), + 'MICROSOFT_REDIRECT_URI', + 'oauth.microsoft.redirect_uri', + os.environ.get('MICROSOFT_REDIRECT_URI', ''), ) GITHUB_CLIENT_ID = PersistentConfig( - "GITHUB_CLIENT_ID", - "oauth.github.client_id", - os.environ.get("GITHUB_CLIENT_ID", ""), + 'GITHUB_CLIENT_ID', + 'oauth.github.client_id', + os.environ.get('GITHUB_CLIENT_ID', ''), ) GITHUB_CLIENT_SECRET = PersistentConfig( - "GITHUB_CLIENT_SECRET", - "oauth.github.client_secret", - os.environ.get("GITHUB_CLIENT_SECRET", ""), + 'GITHUB_CLIENT_SECRET', + 'oauth.github.client_secret', + os.environ.get('GITHUB_CLIENT_SECRET', ''), ) GITHUB_CLIENT_SCOPE = PersistentConfig( - "GITHUB_CLIENT_SCOPE", - "oauth.github.scope", - os.environ.get("GITHUB_CLIENT_SCOPE", "user:email"), + 'GITHUB_CLIENT_SCOPE', + 'oauth.github.scope', + os.environ.get('GITHUB_CLIENT_SCOPE', 'user:email'), ) GITHUB_CLIENT_REDIRECT_URI = PersistentConfig( - "GITHUB_CLIENT_REDIRECT_URI", - "oauth.github.redirect_uri", - os.environ.get("GITHUB_CLIENT_REDIRECT_URI", ""), + 'GITHUB_CLIENT_REDIRECT_URI', + 'oauth.github.redirect_uri', + os.environ.get('GITHUB_CLIENT_REDIRECT_URI', ''), ) OAUTH_CLIENT_ID = PersistentConfig( - "OAUTH_CLIENT_ID", - "oauth.oidc.client_id", - os.environ.get("OAUTH_CLIENT_ID", ""), + 'OAUTH_CLIENT_ID', + 'oauth.oidc.client_id', + os.environ.get('OAUTH_CLIENT_ID', ''), ) OAUTH_CLIENT_SECRET = PersistentConfig( - "OAUTH_CLIENT_SECRET", - "oauth.oidc.client_secret", - os.environ.get("OAUTH_CLIENT_SECRET", ""), + 'OAUTH_CLIENT_SECRET', + 'oauth.oidc.client_secret', + os.environ.get('OAUTH_CLIENT_SECRET', ''), ) OPENID_PROVIDER_URL = PersistentConfig( - "OPENID_PROVIDER_URL", - "oauth.oidc.provider_url", - os.environ.get("OPENID_PROVIDER_URL", ""), + 'OPENID_PROVIDER_URL', + 'oauth.oidc.provider_url', + os.environ.get('OPENID_PROVIDER_URL', ''), ) OPENID_END_SESSION_ENDPOINT = PersistentConfig( - "OPENID_END_SESSION_ENDPOINT", - "oauth.oidc.end_session_endpoint", - os.environ.get("OPENID_END_SESSION_ENDPOINT", ""), + 'OPENID_END_SESSION_ENDPOINT', + 'oauth.oidc.end_session_endpoint', + os.environ.get('OPENID_END_SESSION_ENDPOINT', ''), ) OPENID_REDIRECT_URI = PersistentConfig( - "OPENID_REDIRECT_URI", - "oauth.oidc.redirect_uri", - os.environ.get("OPENID_REDIRECT_URI", ""), + 'OPENID_REDIRECT_URI', + 'oauth.oidc.redirect_uri', + os.environ.get('OPENID_REDIRECT_URI', ''), ) OAUTH_SCOPES = PersistentConfig( - "OAUTH_SCOPES", - "oauth.oidc.scopes", - os.environ.get("OAUTH_SCOPES", "openid email profile"), + 'OAUTH_SCOPES', + 'oauth.oidc.scopes', + os.environ.get('OAUTH_SCOPES', 'openid email profile'), ) OAUTH_TIMEOUT = PersistentConfig( - "OAUTH_TIMEOUT", - "oauth.oidc.oauth_timeout", - os.environ.get("OAUTH_TIMEOUT", ""), + 'OAUTH_TIMEOUT', + 'oauth.oidc.oauth_timeout', + os.environ.get('OAUTH_TIMEOUT', ''), ) OAUTH_TOKEN_ENDPOINT_AUTH_METHOD = PersistentConfig( - "OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", - "oauth.oidc.token_endpoint_auth_method", - os.environ.get("OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", None), + 'OAUTH_TOKEN_ENDPOINT_AUTH_METHOD', + 'oauth.oidc.token_endpoint_auth_method', + os.environ.get('OAUTH_TOKEN_ENDPOINT_AUTH_METHOD', None), ) OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig( - "OAUTH_CODE_CHALLENGE_METHOD", - "oauth.oidc.code_challenge_method", - os.environ.get("OAUTH_CODE_CHALLENGE_METHOD", None), + 'OAUTH_CODE_CHALLENGE_METHOD', + 'oauth.oidc.code_challenge_method', + os.environ.get('OAUTH_CODE_CHALLENGE_METHOD', None), ) OAUTH_PROVIDER_NAME = PersistentConfig( - "OAUTH_PROVIDER_NAME", - "oauth.oidc.provider_name", - os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), + 'OAUTH_PROVIDER_NAME', + 'oauth.oidc.provider_name', + os.environ.get('OAUTH_PROVIDER_NAME', 'SSO'), ) OAUTH_SUB_CLAIM = PersistentConfig( - "OAUTH_SUB_CLAIM", - "oauth.oidc.sub_claim", - os.environ.get("OAUTH_SUB_CLAIM", None), + 'OAUTH_SUB_CLAIM', + 'oauth.oidc.sub_claim', + os.environ.get('OAUTH_SUB_CLAIM', None), ) OAUTH_USERNAME_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", - "oauth.oidc.username_claim", - os.environ.get("OAUTH_USERNAME_CLAIM", "name"), + 'OAUTH_USERNAME_CLAIM', + 'oauth.oidc.username_claim', + os.environ.get('OAUTH_USERNAME_CLAIM', 'name'), ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_PICTURE_CLAIM", - "oauth.oidc.avatar_claim", - os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), + 'OAUTH_PICTURE_CLAIM', + 'oauth.oidc.avatar_claim', + os.environ.get('OAUTH_PICTURE_CLAIM', 'picture'), ) OAUTH_EMAIL_CLAIM = PersistentConfig( - "OAUTH_EMAIL_CLAIM", - "oauth.oidc.email_claim", - os.environ.get("OAUTH_EMAIL_CLAIM", "email"), + 'OAUTH_EMAIL_CLAIM', + 'oauth.oidc.email_claim', + os.environ.get('OAUTH_EMAIL_CLAIM', 'email'), ) OAUTH_GROUPS_CLAIM = PersistentConfig( - "OAUTH_GROUPS_CLAIM", - "oauth.oidc.group_claim", - os.environ.get("OAUTH_GROUPS_CLAIM", os.environ.get("OAUTH_GROUP_CLAIM", "groups")), + 'OAUTH_GROUPS_CLAIM', + 'oauth.oidc.group_claim', + os.environ.get('OAUTH_GROUPS_CLAIM', os.environ.get('OAUTH_GROUP_CLAIM', 'groups')), ) FEISHU_CLIENT_ID = PersistentConfig( - "FEISHU_CLIENT_ID", - "oauth.feishu.client_id", - os.environ.get("FEISHU_CLIENT_ID", ""), + 'FEISHU_CLIENT_ID', + 'oauth.feishu.client_id', + os.environ.get('FEISHU_CLIENT_ID', ''), ) FEISHU_CLIENT_SECRET = PersistentConfig( - "FEISHU_CLIENT_SECRET", - "oauth.feishu.client_secret", - os.environ.get("FEISHU_CLIENT_SECRET", ""), + 'FEISHU_CLIENT_SECRET', + 'oauth.feishu.client_secret', + os.environ.get('FEISHU_CLIENT_SECRET', ''), ) FEISHU_OAUTH_SCOPE = PersistentConfig( - "FEISHU_OAUTH_SCOPE", - "oauth.feishu.scope", - os.environ.get("FEISHU_OAUTH_SCOPE", "contact:user.base:readonly"), + 'FEISHU_OAUTH_SCOPE', + 'oauth.feishu.scope', + os.environ.get('FEISHU_OAUTH_SCOPE', 'contact:user.base:readonly'), ) FEISHU_REDIRECT_URI = PersistentConfig( - "FEISHU_REDIRECT_URI", - "oauth.feishu.redirect_uri", - os.environ.get("FEISHU_REDIRECT_URI", ""), + 'FEISHU_REDIRECT_URI', + 'oauth.feishu.redirect_uri', + os.environ.get('FEISHU_REDIRECT_URI', ''), ) ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( - "ENABLE_OAUTH_ROLE_MANAGEMENT", - "oauth.enable_role_mapping", - os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", + 'ENABLE_OAUTH_ROLE_MANAGEMENT', + 'oauth.enable_role_mapping', + os.environ.get('ENABLE_OAUTH_ROLE_MANAGEMENT', 'False').lower() == 'true', ) ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig( - "ENABLE_OAUTH_GROUP_MANAGEMENT", - "oauth.enable_group_mapping", - os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true", + 'ENABLE_OAUTH_GROUP_MANAGEMENT', + 'oauth.enable_group_mapping', + os.environ.get('ENABLE_OAUTH_GROUP_MANAGEMENT', 'False').lower() == 'true', ) ENABLE_OAUTH_GROUP_CREATION = PersistentConfig( - "ENABLE_OAUTH_GROUP_CREATION", - "oauth.enable_group_creation", - os.environ.get("ENABLE_OAUTH_GROUP_CREATION", "False").lower() == "true", + 'ENABLE_OAUTH_GROUP_CREATION', + 'oauth.enable_group_creation', + os.environ.get('ENABLE_OAUTH_GROUP_CREATION', 'False').lower() == 'true', ) -oauth_group_default_share = ( - os.environ.get("OAUTH_GROUP_DEFAULT_SHARE", "true").strip().lower() -) +oauth_group_default_share = os.environ.get('OAUTH_GROUP_DEFAULT_SHARE', 'true').strip().lower() OAUTH_GROUP_DEFAULT_SHARE = PersistentConfig( - "OAUTH_GROUP_DEFAULT_SHARE", - "oauth.group_default_share", - ( - "members" - if oauth_group_default_share == "members" - else oauth_group_default_share == "true" - ), + 'OAUTH_GROUP_DEFAULT_SHARE', + 'oauth.group_default_share', + ('members' if oauth_group_default_share == 'members' else oauth_group_default_share == 'true'), ) OAUTH_BLOCKED_GROUPS = PersistentConfig( - "OAUTH_BLOCKED_GROUPS", - "oauth.blocked_groups", - os.environ.get("OAUTH_BLOCKED_GROUPS", "[]"), + 'OAUTH_BLOCKED_GROUPS', + 'oauth.blocked_groups', + os.environ.get('OAUTH_BLOCKED_GROUPS', '[]'), ) -OAUTH_GROUPS_SEPARATOR = os.environ.get("OAUTH_GROUPS_SEPARATOR", ";") +OAUTH_GROUPS_SEPARATOR = os.environ.get('OAUTH_GROUPS_SEPARATOR', ';') OAUTH_ROLES_CLAIM = PersistentConfig( - "OAUTH_ROLES_CLAIM", - "oauth.roles_claim", - os.environ.get("OAUTH_ROLES_CLAIM", "roles"), + 'OAUTH_ROLES_CLAIM', + 'oauth.roles_claim', + os.environ.get('OAUTH_ROLES_CLAIM', 'roles'), ) -OAUTH_ROLES_SEPARATOR = os.environ.get("OAUTH_ROLES_SEPARATOR", ",") +OAUTH_ROLES_SEPARATOR = os.environ.get('OAUTH_ROLES_SEPARATOR', ',') OAUTH_ALLOWED_ROLES = PersistentConfig( - "OAUTH_ALLOWED_ROLES", - "oauth.allowed_roles", + 'OAUTH_ALLOWED_ROLES', + 'oauth.allowed_roles', [ role.strip() - for role in os.environ.get( - "OAUTH_ALLOWED_ROLES", f"user{OAUTH_ROLES_SEPARATOR}admin" - ).split(OAUTH_ROLES_SEPARATOR) - if role - ], -) - -OAUTH_ADMIN_ROLES = PersistentConfig( - "OAUTH_ADMIN_ROLES", - "oauth.admin_roles", - [ - role.strip() - for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split( + for role in os.environ.get('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split( OAUTH_ROLES_SEPARATOR ) if role ], ) +OAUTH_ADMIN_ROLES = PersistentConfig( + 'OAUTH_ADMIN_ROLES', + 'oauth.admin_roles', + [role.strip() for role in os.environ.get('OAUTH_ADMIN_ROLES', 'admin').split(OAUTH_ROLES_SEPARATOR) if role], +) + OAUTH_ALLOWED_DOMAINS = PersistentConfig( - "OAUTH_ALLOWED_DOMAINS", - "oauth.allowed_domains", - [ - domain.strip() - for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",") - ], + 'OAUTH_ALLOWED_DOMAINS', + 'oauth.allowed_domains', + [domain.strip() for domain in os.environ.get('OAUTH_ALLOWED_DOMAINS', '*').split(',')], ) OAUTH_UPDATE_PICTURE_ON_LOGIN = PersistentConfig( - "OAUTH_UPDATE_PICTURE_ON_LOGIN", - "oauth.update_picture_on_login", - os.environ.get("OAUTH_UPDATE_PICTURE_ON_LOGIN", "False").lower() == "true", + 'OAUTH_UPDATE_PICTURE_ON_LOGIN', + 'oauth.update_picture_on_login', + os.environ.get('OAUTH_UPDATE_PICTURE_ON_LOGIN', 'False').lower() == 'true', ) OAUTH_UPDATE_NAME_ON_LOGIN = PersistentConfig( - "OAUTH_UPDATE_NAME_ON_LOGIN", - "oauth.update_name_on_login", - os.environ.get("OAUTH_UPDATE_NAME_ON_LOGIN", "False").lower() == "true", + 'OAUTH_UPDATE_NAME_ON_LOGIN', + 'oauth.update_name_on_login', + os.environ.get('OAUTH_UPDATE_NAME_ON_LOGIN', 'False').lower() == 'true', ) OAUTH_UPDATE_EMAIL_ON_LOGIN = PersistentConfig( - "OAUTH_UPDATE_EMAIL_ON_LOGIN", - "oauth.update_email_on_login", - os.environ.get("OAUTH_UPDATE_EMAIL_ON_LOGIN", "False").lower() == "true", + 'OAUTH_UPDATE_EMAIL_ON_LOGIN', + 'oauth.update_email_on_login', + os.environ.get('OAUTH_UPDATE_EMAIL_ON_LOGIN', 'False').lower() == 'true', ) OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID = ( - os.environ.get("OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID", "False").lower() - == "true" + os.environ.get('OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID', 'False').lower() == 'true' ) OAUTH_AUDIENCE = PersistentConfig( - "OAUTH_AUDIENCE", - "oauth.audience", - os.environ.get("OAUTH_AUDIENCE", ""), + 'OAUTH_AUDIENCE', + 'oauth.audience', + os.environ.get('OAUTH_AUDIENCE', ''), ) @@ -684,84 +649,68 @@ def load_oauth_providers(): def google_oauth_register(oauth: OAuth): client = oauth.register( - name="google", + name='google', client_id=GOOGLE_CLIENT_ID.value, client_secret=GOOGLE_CLIENT_SECRET.value, - server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', client_kwargs={ - "scope": GOOGLE_OAUTH_SCOPE.value, - **( - {"timeout": int(OAUTH_TIMEOUT.value)} - if OAUTH_TIMEOUT.value - else {} - ), + 'scope': GOOGLE_OAUTH_SCOPE.value, + **({'timeout': int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}), }, redirect_uri=GOOGLE_REDIRECT_URI.value, ) return client - OAUTH_PROVIDERS["google"] = { - "redirect_uri": GOOGLE_REDIRECT_URI.value, - "register": google_oauth_register, + OAUTH_PROVIDERS['google'] = { + 'redirect_uri': GOOGLE_REDIRECT_URI.value, + 'register': google_oauth_register, } - if ( - MICROSOFT_CLIENT_ID.value - and MICROSOFT_CLIENT_SECRET.value - and MICROSOFT_CLIENT_TENANT_ID.value - ): + if MICROSOFT_CLIENT_ID.value and MICROSOFT_CLIENT_SECRET.value and MICROSOFT_CLIENT_TENANT_ID.value: def microsoft_oauth_register(oauth: OAuth): client = oauth.register( - name="microsoft", + name='microsoft', client_id=MICROSOFT_CLIENT_ID.value, client_secret=MICROSOFT_CLIENT_SECRET.value, - server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}", + server_metadata_url=f'{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}', client_kwargs={ - "scope": MICROSOFT_OAUTH_SCOPE.value, - **( - {"timeout": int(OAUTH_TIMEOUT.value)} - if OAUTH_TIMEOUT.value - else {} - ), + 'scope': MICROSOFT_OAUTH_SCOPE.value, + **({'timeout': int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}), }, redirect_uri=MICROSOFT_REDIRECT_URI.value, ) return client - OAUTH_PROVIDERS["microsoft"] = { - "redirect_uri": MICROSOFT_REDIRECT_URI.value, - "picture_url": MICROSOFT_CLIENT_PICTURE_URL.value, - "register": microsoft_oauth_register, + OAUTH_PROVIDERS['microsoft'] = { + 'redirect_uri': MICROSOFT_REDIRECT_URI.value, + 'picture_url': MICROSOFT_CLIENT_PICTURE_URL.value, + 'register': microsoft_oauth_register, } if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value: def github_oauth_register(oauth: OAuth): client = oauth.register( - name="github", + name='github', client_id=GITHUB_CLIENT_ID.value, client_secret=GITHUB_CLIENT_SECRET.value, - access_token_url="https://github.com/login/oauth/access_token", - authorize_url="https://github.com/login/oauth/authorize", - api_base_url="https://api.github.com", - userinfo_endpoint="https://api.github.com/user", + access_token_url='https://github.com/login/oauth/access_token', + authorize_url='https://github.com/login/oauth/authorize', + api_base_url='https://api.github.com', + userinfo_endpoint='https://api.github.com/user', client_kwargs={ - "scope": GITHUB_CLIENT_SCOPE.value, - **( - {"timeout": int(OAUTH_TIMEOUT.value)} - if OAUTH_TIMEOUT.value - else {} - ), + 'scope': GITHUB_CLIENT_SCOPE.value, + **({'timeout': int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}), }, redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value, ) return client - OAUTH_PROVIDERS["github"] = { - "redirect_uri": GITHUB_CLIENT_REDIRECT_URI.value, - "register": github_oauth_register, - "sub_claim": "id", + OAUTH_PROVIDERS['github'] = { + 'redirect_uri': GITHUB_CLIENT_REDIRECT_URI.value, + 'register': github_oauth_register, + 'sub_claim': 'id', } if ( @@ -772,32 +721,25 @@ def load_oauth_providers(): def oidc_oauth_register(oauth: OAuth): client_kwargs = { - "scope": OAUTH_SCOPES.value, + 'scope': OAUTH_SCOPES.value, **( - { - "token_endpoint_auth_method": OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value - } + {'token_endpoint_auth_method': OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value} if OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value else {} ), - **( - {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {} - ), + **({'timeout': int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}), } - if ( - OAUTH_CODE_CHALLENGE_METHOD.value - and OAUTH_CODE_CHALLENGE_METHOD.value == "S256" - ): - client_kwargs["code_challenge_method"] = "S256" + if OAUTH_CODE_CHALLENGE_METHOD.value and OAUTH_CODE_CHALLENGE_METHOD.value == 'S256': + client_kwargs['code_challenge_method'] = 'S256' elif OAUTH_CODE_CHALLENGE_METHOD.value: raise Exception( 'Code challenge methods other than "%s" not supported. Given: "%s"' - % ("S256", OAUTH_CODE_CHALLENGE_METHOD.value) + % ('S256', OAUTH_CODE_CHALLENGE_METHOD.value) ) client = oauth.register( - name="oidc", + name='oidc', client_id=OAUTH_CLIENT_ID.value, client_secret=OAUTH_CLIENT_SECRET.value, server_metadata_url=OPENID_PROVIDER_URL.value, @@ -806,62 +748,54 @@ def load_oauth_providers(): ) return client - OAUTH_PROVIDERS["oidc"] = { - "name": OAUTH_PROVIDER_NAME.value, - "redirect_uri": OPENID_REDIRECT_URI.value, - "register": oidc_oauth_register, + OAUTH_PROVIDERS['oidc'] = { + 'name': OAUTH_PROVIDER_NAME.value, + 'redirect_uri': OPENID_REDIRECT_URI.value, + 'register': oidc_oauth_register, } if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value: def feishu_oauth_register(oauth: OAuth): client = oauth.register( - name="feishu", + name='feishu', client_id=FEISHU_CLIENT_ID.value, client_secret=FEISHU_CLIENT_SECRET.value, - access_token_url="https://open.feishu.cn/open-apis/authen/v2/oauth/token", - authorize_url="https://accounts.feishu.cn/open-apis/authen/v1/authorize", - api_base_url="https://open.feishu.cn/open-apis", - userinfo_endpoint="https://open.feishu.cn/open-apis/authen/v1/user_info", + access_token_url='https://open.feishu.cn/open-apis/authen/v2/oauth/token', + authorize_url='https://accounts.feishu.cn/open-apis/authen/v1/authorize', + api_base_url='https://open.feishu.cn/open-apis', + userinfo_endpoint='https://open.feishu.cn/open-apis/authen/v1/user_info', client_kwargs={ - "scope": FEISHU_OAUTH_SCOPE.value, - **( - {"timeout": int(OAUTH_TIMEOUT.value)} - if OAUTH_TIMEOUT.value - else {} - ), + 'scope': FEISHU_OAUTH_SCOPE.value, + **({'timeout': int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}), }, redirect_uri=FEISHU_REDIRECT_URI.value, ) return client - OAUTH_PROVIDERS["feishu"] = { - "register": feishu_oauth_register, - "sub_claim": "user_id", + OAUTH_PROVIDERS['feishu'] = { + 'register': feishu_oauth_register, + 'sub_claim': 'user_id', } configured_providers = [] if GOOGLE_CLIENT_ID.value: - configured_providers.append("Google") + configured_providers.append('Google') if MICROSOFT_CLIENT_ID.value: - configured_providers.append("Microsoft") + configured_providers.append('Microsoft') if GITHUB_CLIENT_ID.value: - configured_providers.append("GitHub") + configured_providers.append('GitHub') if FEISHU_CLIENT_ID.value: - configured_providers.append("Feishu") + configured_providers.append('Feishu') - if ( - configured_providers - and not OPENID_PROVIDER_URL.value - and not OPENID_END_SESSION_ENDPOINT.value - ): - provider_list = ", ".join(configured_providers) + if configured_providers and not OPENID_PROVIDER_URL.value and not OPENID_END_SESSION_ENDPOINT.value: + provider_list = ', '.join(configured_providers) log.warning( - f"⚠️ OAuth providers configured ({provider_list}) but OPENID_PROVIDER_URL not set - logout will not work!" + f'⚠️ OAuth providers configured ({provider_list}) but OPENID_PROVIDER_URL not set - logout will not work!' ) log.warning( f"Set OPENID_PROVIDER_URL to your OAuth provider's OpenID Connect discovery endpoint," - f" or set OPENID_END_SESSION_ENDPOINT to a custom logout URL to fix logout functionality." + f' or set OPENID_END_SESSION_ENDPOINT to a custom logout URL to fix logout functionality.' ) @@ -871,7 +805,7 @@ load_oauth_providers() # Static DIR #################################### -STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve() +STATIC_DIR = Path(os.getenv('STATIC_DIR', OPEN_WEBUI_DIR / 'static')).resolve() try: if STATIC_DIR.exists(): @@ -884,80 +818,72 @@ try: except Exception as e: pass -for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"): +for file_path in (FRONTEND_BUILD_DIR / 'static').glob('**/*'): if file_path.is_file(): - target_path = STATIC_DIR / file_path.relative_to( - (FRONTEND_BUILD_DIR / "static") - ) + target_path = STATIC_DIR / file_path.relative_to((FRONTEND_BUILD_DIR / 'static')) target_path.parent.mkdir(parents=True, exist_ok=True) try: shutil.copyfile(file_path, target_path) except Exception as e: - logging.error(f"An error occurred: {e}") + logging.error(f'An error occurred: {e}') -frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png" +frontend_favicon = FRONTEND_BUILD_DIR / 'static' / 'favicon.png' if frontend_favicon.exists(): try: - shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") + shutil.copyfile(frontend_favicon, STATIC_DIR / 'favicon.png') except Exception as e: - logging.error(f"An error occurred: {e}") + logging.error(f'An error occurred: {e}') -frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png" +frontend_splash = FRONTEND_BUILD_DIR / 'static' / 'splash.png' if frontend_splash.exists(): try: - shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png") + shutil.copyfile(frontend_splash, STATIC_DIR / 'splash.png') except Exception as e: - logging.error(f"An error occurred: {e}") + logging.error(f'An error occurred: {e}') -frontend_loader = FRONTEND_BUILD_DIR / "static" / "loader.js" +frontend_loader = FRONTEND_BUILD_DIR / 'static' / 'loader.js' if frontend_loader.exists(): try: - shutil.copyfile(frontend_loader, STATIC_DIR / "loader.js") + shutil.copyfile(frontend_loader, STATIC_DIR / 'loader.js') except Exception as e: - logging.error(f"An error occurred: {e}") + logging.error(f'An error occurred: {e}') #################################### # CUSTOM_NAME (Legacy) #################################### -CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "") +CUSTOM_NAME = os.environ.get('CUSTOM_NAME', '') if CUSTOM_NAME: try: - r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}") + r = requests.get(f'https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}') data = r.json() if r.ok: - if "logo" in data: + if 'logo' in data: WEBUI_FAVICON_URL = url = ( - f"https://api.openwebui.com{data['logo']}" - if data["logo"][0] == "/" - else data["logo"] + f'https://api.openwebui.com{data["logo"]}' if data['logo'][0] == '/' else data['logo'] ) r = requests.get(url, stream=True) if r.status_code == 200: - with open(f"{STATIC_DIR}/favicon.png", "wb") as f: + with open(f'{STATIC_DIR}/favicon.png', 'wb') as f: r.raw.decode_content = True shutil.copyfileobj(r.raw, f) - if "splash" in data: - url = ( - f"https://api.openwebui.com{data['splash']}" - if data["splash"][0] == "/" - else data["splash"] - ) + if 'splash' in data: + url = f'https://api.openwebui.com{data["splash"]}' if data['splash'][0] == '/' else data['splash'] r = requests.get(url, stream=True) if r.status_code == 200: - with open(f"{STATIC_DIR}/splash.png", "wb") as f: + with open(f'{STATIC_DIR}/splash.png', 'wb') as f: r.raw.decode_content = True shutil.copyfileobj(r.raw, f) - WEBUI_NAME = data["name"] + WEBUI_NAME = data['name'] except Exception as e: log.exception(e) pass @@ -967,34 +893,30 @@ if CUSTOM_NAME: # STORAGE PROVIDER #################################### -STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "local") # defaults to local, s3 +STORAGE_PROVIDER = os.environ.get('STORAGE_PROVIDER', 'local') # defaults to local, s3 -S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None) -S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None) -S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) -S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) -S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None) -S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) -S3_USE_ACCELERATE_ENDPOINT = ( - os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "false").lower() == "true" -) -S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None) -S3_ENABLE_TAGGING = os.getenv("S3_ENABLE_TAGGING", "false").lower() == "true" +S3_ACCESS_KEY_ID = os.environ.get('S3_ACCESS_KEY_ID', None) +S3_SECRET_ACCESS_KEY = os.environ.get('S3_SECRET_ACCESS_KEY', None) +S3_REGION_NAME = os.environ.get('S3_REGION_NAME', None) +S3_BUCKET_NAME = os.environ.get('S3_BUCKET_NAME', None) +S3_KEY_PREFIX = os.environ.get('S3_KEY_PREFIX', None) +S3_ENDPOINT_URL = os.environ.get('S3_ENDPOINT_URL', None) +S3_USE_ACCELERATE_ENDPOINT = os.environ.get('S3_USE_ACCELERATE_ENDPOINT', 'false').lower() == 'true' +S3_ADDRESSING_STYLE = os.environ.get('S3_ADDRESSING_STYLE', None) +S3_ENABLE_TAGGING = os.getenv('S3_ENABLE_TAGGING', 'false').lower() == 'true' -GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None) -GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get( - "GOOGLE_APPLICATION_CREDENTIALS_JSON", None -) +GCS_BUCKET_NAME = os.environ.get('GCS_BUCKET_NAME', None) +GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS_JSON', None) -AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None) -AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None) -AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None) +AZURE_STORAGE_ENDPOINT = os.environ.get('AZURE_STORAGE_ENDPOINT', None) +AZURE_STORAGE_CONTAINER_NAME = os.environ.get('AZURE_STORAGE_CONTAINER_NAME', None) +AZURE_STORAGE_KEY = os.environ.get('AZURE_STORAGE_KEY', None) #################################### # File Upload DIR #################################### -UPLOAD_DIR = DATA_DIR / "uploads" +UPLOAD_DIR = DATA_DIR / 'uploads' UPLOAD_DIR.mkdir(parents=True, exist_ok=True) @@ -1002,7 +924,7 @@ UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # Cache DIR #################################### -CACHE_DIR = DATA_DIR / "cache" +CACHE_DIR = DATA_DIR / 'cache' CACHE_DIR.mkdir(parents=True, exist_ok=True) @@ -1011,9 +933,9 @@ CACHE_DIR.mkdir(parents=True, exist_ok=True) #################################### ENABLE_DIRECT_CONNECTIONS = PersistentConfig( - "ENABLE_DIRECT_CONNECTIONS", - "direct.enable", - os.environ.get("ENABLE_DIRECT_CONNECTIONS", "False").lower() == "true", + 'ENABLE_DIRECT_CONNECTIONS', + 'direct.enable', + os.environ.get('ENABLE_DIRECT_CONNECTIONS', 'False').lower() == 'true', ) #################################### @@ -1021,43 +943,35 @@ ENABLE_DIRECT_CONNECTIONS = PersistentConfig( #################################### ENABLE_OLLAMA_API = PersistentConfig( - "ENABLE_OLLAMA_API", - "ollama.enable", - os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true", + 'ENABLE_OLLAMA_API', + 'ollama.enable', + os.environ.get('ENABLE_OLLAMA_API', 'True').lower() == 'true', ) -OLLAMA_API_BASE_URL = os.environ.get( - "OLLAMA_API_BASE_URL", "http://localhost:11434/api" -) +OLLAMA_API_BASE_URL = os.environ.get('OLLAMA_API_BASE_URL', 'http://localhost:11434/api') -OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +OLLAMA_BASE_URL = os.environ.get('OLLAMA_BASE_URL', '') if OLLAMA_BASE_URL: # Remove trailing slash - OLLAMA_BASE_URL = ( - OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL - ) + OLLAMA_BASE_URL = OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith('/') else OLLAMA_BASE_URL -K8S_FLAG = os.environ.get("K8S_FLAG", "") -USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") +K8S_FLAG = os.environ.get('K8S_FLAG', '') +USE_OLLAMA_DOCKER = os.environ.get('USE_OLLAMA_DOCKER', 'false') -if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": - OLLAMA_BASE_URL = ( - OLLAMA_API_BASE_URL[:-4] - if OLLAMA_API_BASE_URL.endswith("/api") - else OLLAMA_API_BASE_URL - ) +if OLLAMA_BASE_URL == '' and OLLAMA_API_BASE_URL != '': + OLLAMA_BASE_URL = OLLAMA_API_BASE_URL[:-4] if OLLAMA_API_BASE_URL.endswith('/api') else OLLAMA_API_BASE_URL -if ENV == "prod": - if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG: - if USE_OLLAMA_DOCKER.lower() == "true": +if ENV == 'prod': + if OLLAMA_BASE_URL == '/ollama' and not K8S_FLAG: + if USE_OLLAMA_DOCKER.lower() == 'true': # if you use all-in-one docker container (Open WebUI + Ollama) # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434 - OLLAMA_BASE_URL = "http://localhost:11434" + OLLAMA_BASE_URL = 'http://localhost:11434' else: - OLLAMA_BASE_URL = "http://host.docker.internal:11434" + OLLAMA_BASE_URL = 'http://host.docker.internal:11434' elif K8S_FLAG: - OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" + OLLAMA_BASE_URL = 'http://ollama-service.open-webui.svc.cluster.local:11434' def _resolve_ollama_base_url(url: str) -> str: @@ -1070,40 +984,36 @@ def _resolve_ollama_base_url(url: str) -> str: except (OSError, TimeoutError): return False - host = urlparse(url).hostname or "localhost" + host = urlparse(url).hostname or 'localhost' with ThreadPoolExecutor(max_workers=2) as pool: default = pool.submit(reachable, host, 11434) fallback = pool.submit(reachable, host, 12434) if not default.result() and fallback.result(): - url = url.replace(":11434", ":12434") - log.info(f"Ollama port 11434 unreachable on {host}, falling back to 12434") + url = url.replace(':11434', ':12434') + log.info(f'Ollama port 11434 unreachable on {host}, falling back to 12434') elif not default.result(): - log.info(f"Ollama ports 11434 and 12434 both unreachable on {host}") + log.info(f'Ollama ports 11434 and 12434 both unreachable on {host}') return url # Auto-resolve Ollama port when no explicit URL was provided by the user. # The Dockerfile default is "/ollama" which the block above rewrites to :11434. -if os.environ.get("OLLAMA_BASE_URL", "") in ("", "/ollama") and not os.environ.get( - "OLLAMA_BASE_URLS", "" -): +if os.environ.get('OLLAMA_BASE_URL', '') in ('', '/ollama') and not os.environ.get('OLLAMA_BASE_URLS', ''): OLLAMA_BASE_URL = _resolve_ollama_base_url(OLLAMA_BASE_URL) -OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") -OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL +OLLAMA_BASE_URLS = os.environ.get('OLLAMA_BASE_URLS', '') +OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != '' else OLLAMA_BASE_URL -OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] -OLLAMA_BASE_URLS = PersistentConfig( - "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS -) +OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(';')] +OLLAMA_BASE_URLS = PersistentConfig('OLLAMA_BASE_URLS', 'ollama.base_urls', OLLAMA_BASE_URLS) OLLAMA_API_CONFIGS = PersistentConfig( - "OLLAMA_API_CONFIGS", - "ollama.api_configs", + 'OLLAMA_API_CONFIGS', + 'ollama.api_configs', {}, ) @@ -1113,61 +1023,52 @@ OLLAMA_API_CONFIGS = PersistentConfig( ENABLE_OPENAI_API = PersistentConfig( - "ENABLE_OPENAI_API", - "openai.enable", - os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true", + 'ENABLE_OPENAI_API', + 'openai.enable', + os.environ.get('ENABLE_OPENAI_API', 'True').lower() == 'true', ) -OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") -OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") +OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '') +OPENAI_API_BASE_URL = os.environ.get('OPENAI_API_BASE_URL', '') -GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") -GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "") +GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', '') +GEMINI_API_BASE_URL = os.environ.get('GEMINI_API_BASE_URL', '') -if OPENAI_API_BASE_URL == "": - OPENAI_API_BASE_URL = "https://api.openai.com/v1" +if OPENAI_API_BASE_URL == '': + OPENAI_API_BASE_URL = 'https://api.openai.com/v1' else: - if OPENAI_API_BASE_URL.endswith("/"): + if OPENAI_API_BASE_URL.endswith('/'): OPENAI_API_BASE_URL = OPENAI_API_BASE_URL[:-1] -OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") -OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY +OPENAI_API_KEYS = os.environ.get('OPENAI_API_KEYS', '') +OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != '' else OPENAI_API_KEY -OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] -OPENAI_API_KEYS = PersistentConfig( - "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS -) +OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(';')] +OPENAI_API_KEYS = PersistentConfig('OPENAI_API_KEYS', 'openai.api_keys', OPENAI_API_KEYS) -OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") -OPENAI_API_BASE_URLS = ( - OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL -) +OPENAI_API_BASE_URLS = os.environ.get('OPENAI_API_BASE_URLS', '') +OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != '' else OPENAI_API_BASE_URL OPENAI_API_BASE_URLS = [ - url.strip() if url != "" else "https://api.openai.com/v1" - for url in OPENAI_API_BASE_URLS.split(";") + url.strip() if url != '' else 'https://api.openai.com/v1' for url in OPENAI_API_BASE_URLS.split(';') ] -OPENAI_API_BASE_URLS = PersistentConfig( - "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS -) +OPENAI_API_BASE_URLS = PersistentConfig('OPENAI_API_BASE_URLS', 'openai.api_base_urls', OPENAI_API_BASE_URLS) OPENAI_API_CONFIGS = PersistentConfig( - "OPENAI_API_CONFIGS", - "openai.api_configs", + 'OPENAI_API_CONFIGS', + 'openai.api_configs', {}, ) # Get the actual OpenAI API key based on the base URL -OPENAI_API_KEY = "" +OPENAI_API_KEY = '' try: - OPENAI_API_KEY = OPENAI_API_KEYS.value[ - OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") - ] + OPENAI_API_KEY = OPENAI_API_KEYS.value[OPENAI_API_BASE_URLS.value.index('https://api.openai.com/v1')] except Exception: pass -OPENAI_API_BASE_URL = "https://api.openai.com/v1" +OPENAI_API_BASE_URL = 'https://api.openai.com/v1' #################################### @@ -1175,9 +1076,9 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### ENABLE_BASE_MODELS_CACHE = PersistentConfig( - "ENABLE_BASE_MODELS_CACHE", - "models.base_models_cache", - os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true", + 'ENABLE_BASE_MODELS_CACHE', + 'models.base_models_cache', + os.environ.get('ENABLE_BASE_MODELS_CACHE', 'False').lower() == 'true', ) @@ -1186,17 +1087,15 @@ ENABLE_BASE_MODELS_CACHE = PersistentConfig( #################################### try: - tool_server_connections = json.loads( - os.environ.get("TOOL_SERVER_CONNECTIONS", "[]") - ) + tool_server_connections = json.loads(os.environ.get('TOOL_SERVER_CONNECTIONS', '[]')) except Exception as e: - log.exception(f"Error loading TOOL_SERVER_CONNECTIONS: {e}") + log.exception(f'Error loading TOOL_SERVER_CONNECTIONS: {e}') tool_server_connections = [] TOOL_SERVER_CONNECTIONS = PersistentConfig( - "TOOL_SERVER_CONNECTIONS", - "tool_server.connections", + 'TOOL_SERVER_CONNECTIONS', + 'tool_server.connections', tool_server_connections, ) @@ -1204,13 +1103,11 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig( # TERMINAL_SERVER #################################### -terminal_server_connections = json.loads( - os.environ.get("TERMINAL_SERVER_CONNECTIONS", "[]") -) +terminal_server_connections = json.loads(os.environ.get('TERMINAL_SERVER_CONNECTIONS', '[]')) TERMINAL_SERVER_CONNECTIONS = PersistentConfig( - "TERMINAL_SERVER_CONNECTIONS", - "terminal_server.connections", + 'TERMINAL_SERVER_CONNECTIONS', + 'terminal_server.connections', terminal_server_connections, ) @@ -1219,579 +1116,478 @@ TERMINAL_SERVER_CONNECTIONS = PersistentConfig( #################################### -WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "")) +WEBUI_URL = PersistentConfig('WEBUI_URL', 'webui.url', os.environ.get('WEBUI_URL', '')) ENABLE_SIGNUP = PersistentConfig( - "ENABLE_SIGNUP", - "ui.enable_signup", - ( - False - if not WEBUI_AUTH - else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" - ), + 'ENABLE_SIGNUP', + 'ui.enable_signup', + (False if not WEBUI_AUTH else os.environ.get('ENABLE_SIGNUP', 'True').lower() == 'true'), ) ENABLE_LOGIN_FORM = PersistentConfig( - "ENABLE_LOGIN_FORM", - "ui.ENABLE_LOGIN_FORM", - os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true", + 'ENABLE_LOGIN_FORM', + 'ui.ENABLE_LOGIN_FORM', + os.environ.get('ENABLE_LOGIN_FORM', 'True').lower() == 'true', ) -ENABLE_PASSWORD_AUTH = os.environ.get("ENABLE_PASSWORD_AUTH", "True").lower() == "true" +ENABLE_PASSWORD_AUTH = os.environ.get('ENABLE_PASSWORD_AUTH', 'True').lower() == 'true' DEFAULT_LOCALE = PersistentConfig( - "DEFAULT_LOCALE", - "ui.default_locale", - os.environ.get("DEFAULT_LOCALE", ""), + 'DEFAULT_LOCALE', + 'ui.default_locale', + os.environ.get('DEFAULT_LOCALE', ''), ) -DEFAULT_MODELS = PersistentConfig( - "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) -) +DEFAULT_MODELS = PersistentConfig('DEFAULT_MODELS', 'ui.default_models', os.environ.get('DEFAULT_MODELS', None)) DEFAULT_PINNED_MODELS = PersistentConfig( - "DEFAULT_PINNED_MODELS", - "ui.default_pinned_models", - os.environ.get("DEFAULT_PINNED_MODELS", None), + 'DEFAULT_PINNED_MODELS', + 'ui.default_pinned_models', + os.environ.get('DEFAULT_PINNED_MODELS', None), ) try: - default_prompt_suggestions = json.loads( - os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]") - ) + default_prompt_suggestions = json.loads(os.environ.get('DEFAULT_PROMPT_SUGGESTIONS', '[]')) except Exception as e: - log.exception(f"Error loading DEFAULT_PROMPT_SUGGESTIONS: {e}") + log.exception(f'Error loading DEFAULT_PROMPT_SUGGESTIONS: {e}') default_prompt_suggestions = [] if default_prompt_suggestions == []: default_prompt_suggestions = [ { - "title": ["Help me study", "vocabulary for a college entrance exam"], - "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", + 'title': ['Help me study', 'vocabulary for a college entrance exam'], + 'content': "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", }, { - "title": ["Give me ideas", "for what to do with my kids' art"], - "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.", + 'title': ['Give me ideas', "for what to do with my kids' art"], + 'content': "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.", }, { - "title": ["Tell me a fun fact", "about the Roman Empire"], - "content": "Tell me a random fun fact about the Roman Empire", + 'title': ['Tell me a fun fact', 'about the Roman Empire'], + 'content': 'Tell me a random fun fact about the Roman Empire', }, { - "title": ["Show me a code snippet", "of a website's sticky header"], - "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.", + 'title': ['Show me a code snippet', "of a website's sticky header"], + 'content': "Show me a code snippet of a website's sticky header in CSS and JavaScript.", }, { - "title": [ - "Explain options trading", + 'title': [ + 'Explain options trading', "if I'm familiar with buying and selling stocks", ], - "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.", + 'content': "Explain options trading in simple terms if I'm familiar with buying and selling stocks.", }, { - "title": ["Overcome procrastination", "give me tips"], - "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", + 'title': ['Overcome procrastination', 'give me tips'], + 'content': 'Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?', }, ] DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( - "DEFAULT_PROMPT_SUGGESTIONS", - "ui.prompt_suggestions", + 'DEFAULT_PROMPT_SUGGESTIONS', + 'ui.prompt_suggestions', default_prompt_suggestions, ) MODEL_ORDER_LIST = PersistentConfig( - "MODEL_ORDER_LIST", - "ui.model_order_list", + 'MODEL_ORDER_LIST', + 'ui.model_order_list', [], ) DEFAULT_MODEL_METADATA = PersistentConfig( - "DEFAULT_MODEL_METADATA", - "models.default_metadata", + 'DEFAULT_MODEL_METADATA', + 'models.default_metadata', {}, ) DEFAULT_MODEL_PARAMS = PersistentConfig( - "DEFAULT_MODEL_PARAMS", - "models.default_params", + 'DEFAULT_MODEL_PARAMS', + 'models.default_params', {}, ) DEFAULT_USER_ROLE = PersistentConfig( - "DEFAULT_USER_ROLE", - "ui.default_user_role", - os.getenv("DEFAULT_USER_ROLE", "pending"), + 'DEFAULT_USER_ROLE', + 'ui.default_user_role', + os.getenv('DEFAULT_USER_ROLE', 'pending'), ) DEFAULT_GROUP_ID = PersistentConfig( - "DEFAULT_GROUP_ID", - "ui.default_group_id", - os.environ.get("DEFAULT_GROUP_ID", ""), + 'DEFAULT_GROUP_ID', + 'ui.default_group_id', + os.environ.get('DEFAULT_GROUP_ID', ''), ) PENDING_USER_OVERLAY_TITLE = PersistentConfig( - "PENDING_USER_OVERLAY_TITLE", - "ui.pending_user_overlay_title", - os.environ.get("PENDING_USER_OVERLAY_TITLE", ""), + 'PENDING_USER_OVERLAY_TITLE', + 'ui.pending_user_overlay_title', + os.environ.get('PENDING_USER_OVERLAY_TITLE', ''), ) PENDING_USER_OVERLAY_CONTENT = PersistentConfig( - "PENDING_USER_OVERLAY_CONTENT", - "ui.pending_user_overlay_content", - os.environ.get("PENDING_USER_OVERLAY_CONTENT", ""), + 'PENDING_USER_OVERLAY_CONTENT', + 'ui.pending_user_overlay_content', + os.environ.get('PENDING_USER_OVERLAY_CONTENT', ''), ) RESPONSE_WATERMARK = PersistentConfig( - "RESPONSE_WATERMARK", - "ui.watermark", - os.environ.get("RESPONSE_WATERMARK", ""), + 'RESPONSE_WATERMARK', + 'ui.watermark', + os.environ.get('RESPONSE_WATERMARK', ''), ) USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT", "False").lower() == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT", "False").lower() == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_SHARING = ( - os.environ.get("USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_SHARING", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_PUBLIC_SHARING = ( - os.environ.get( - "USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_PUBLIC_SHARING", "False" - ).lower() - == "true" + os.environ.get('USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) -USER_PERMISSIONS_NOTES_ALLOW_SHARING = ( - os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_SHARING", "False").lower() == "true" -) +USER_PERMISSIONS_NOTES_ALLOW_SHARING = os.environ.get('USER_PERMISSIONS_NOTES_ALLOW_SHARING', 'False').lower() == 'true' USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = ( - os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING', 'False').lower() == 'true' ) USER_PERMISSIONS_ACCESS_GRANTS_ALLOW_USERS = ( - os.environ.get("USER_PERMISSIONS_ACCESS_GRANTS_ALLOW_USERS", "True").lower() - == "true" + os.environ.get('USER_PERMISSIONS_ACCESS_GRANTS_ALLOW_USERS', 'True').lower() == 'true' ) -USER_PERMISSIONS_CHAT_CONTROLS = ( - os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_CONTROLS = os.environ.get('USER_PERMISSIONS_CHAT_CONTROLS', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_VALVES = ( - os.environ.get("USER_PERMISSIONS_CHAT_VALVES", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_VALVES = os.environ.get('USER_PERMISSIONS_CHAT_VALVES', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = ( - os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = os.environ.get('USER_PERMISSIONS_CHAT_SYSTEM_PROMPT', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_PARAMS = ( - os.environ.get("USER_PERMISSIONS_CHAT_PARAMS", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_PARAMS = os.environ.get('USER_PERMISSIONS_CHAT_PARAMS', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( - os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_FILE_UPLOAD = os.environ.get('USER_PERMISSIONS_CHAT_FILE_UPLOAD', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_WEB_UPLOAD = ( - os.environ.get("USER_PERMISSIONS_CHAT_WEB_UPLOAD", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_WEB_UPLOAD = os.environ.get('USER_PERMISSIONS_CHAT_WEB_UPLOAD', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_DELETE = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_DELETE = os.environ.get('USER_PERMISSIONS_CHAT_DELETE', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_DELETE_MESSAGE = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETE_MESSAGE", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_DELETE_MESSAGE = os.environ.get('USER_PERMISSIONS_CHAT_DELETE_MESSAGE', 'True').lower() == 'true' USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = ( - os.environ.get("USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE", "True").lower() == "true" + os.environ.get('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true' ) USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE = ( - os.environ.get("USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE", "True").lower() - == "true" + os.environ.get('USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE', 'True').lower() == 'true' ) -USER_PERMISSIONS_CHAT_RATE_RESPONSE = ( - os.environ.get("USER_PERMISSIONS_CHAT_RATE_RESPONSE", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_RATE_RESPONSE = os.environ.get('USER_PERMISSIONS_CHAT_RATE_RESPONSE', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_EDIT = ( - os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_EDIT = os.environ.get('USER_PERMISSIONS_CHAT_EDIT', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_SHARE = ( - os.environ.get("USER_PERMISSIONS_CHAT_SHARE", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_SHARE = os.environ.get('USER_PERMISSIONS_CHAT_SHARE', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_EXPORT = ( - os.environ.get("USER_PERMISSIONS_CHAT_EXPORT", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_EXPORT = os.environ.get('USER_PERMISSIONS_CHAT_EXPORT', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_STT = ( - os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_STT = os.environ.get('USER_PERMISSIONS_CHAT_STT', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_TTS = ( - os.environ.get("USER_PERMISSIONS_CHAT_TTS", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_TTS = os.environ.get('USER_PERMISSIONS_CHAT_TTS', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_CALL = ( - os.environ.get("USER_PERMISSIONS_CHAT_CALL", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_CALL = os.environ.get('USER_PERMISSIONS_CHAT_CALL', 'True').lower() == 'true' USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = ( - os.environ.get("USER_PERMISSIONS_CHAT_MULTIPLE_MODELS", "True").lower() == "true" + os.environ.get('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true' ) -USER_PERMISSIONS_CHAT_TEMPORARY = ( - os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true" -) +USER_PERMISSIONS_CHAT_TEMPORARY = os.environ.get('USER_PERMISSIONS_CHAT_TEMPORARY', 'True').lower() == 'true' USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED = ( - os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED', 'False').lower() == 'true' ) USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS = ( - os.environ.get("USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS", "False").lower() - == "true" + os.environ.get('USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS', 'False').lower() == 'true' ) -USER_PERMISSIONS_FEATURES_WEB_SEARCH = ( - os.environ.get("USER_PERMISSIONS_FEATURES_WEB_SEARCH", "True").lower() == "true" -) +USER_PERMISSIONS_FEATURES_WEB_SEARCH = os.environ.get('USER_PERMISSIONS_FEATURES_WEB_SEARCH', 'True').lower() == 'true' USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = ( - os.environ.get("USER_PERMISSIONS_FEATURES_IMAGE_GENERATION", "True").lower() - == "true" + os.environ.get('USER_PERMISSIONS_FEATURES_IMAGE_GENERATION', 'True').lower() == 'true' ) USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = ( - os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower() - == "true" + os.environ.get('USER_PERMISSIONS_FEATURES_CODE_INTERPRETER', 'True').lower() == 'true' ) -USER_PERMISSIONS_FEATURES_FOLDERS = ( - os.environ.get("USER_PERMISSIONS_FEATURES_FOLDERS", "True").lower() == "true" -) +USER_PERMISSIONS_FEATURES_FOLDERS = os.environ.get('USER_PERMISSIONS_FEATURES_FOLDERS', 'True').lower() == 'true' -USER_PERMISSIONS_FEATURES_NOTES = ( - os.environ.get("USER_PERMISSIONS_FEATURES_NOTES", "True").lower() == "true" -) +USER_PERMISSIONS_FEATURES_NOTES = os.environ.get('USER_PERMISSIONS_FEATURES_NOTES', 'True').lower() == 'true' -USER_PERMISSIONS_FEATURES_CHANNELS = ( - os.environ.get("USER_PERMISSIONS_FEATURES_CHANNELS", "True").lower() == "true" -) +USER_PERMISSIONS_FEATURES_CHANNELS = os.environ.get('USER_PERMISSIONS_FEATURES_CHANNELS', 'True').lower() == 'true' -USER_PERMISSIONS_FEATURES_API_KEYS = ( - os.environ.get("USER_PERMISSIONS_FEATURES_API_KEYS", "False").lower() == "true" -) +USER_PERMISSIONS_FEATURES_API_KEYS = os.environ.get('USER_PERMISSIONS_FEATURES_API_KEYS', 'False').lower() == 'true' -USER_PERMISSIONS_FEATURES_MEMORIES = ( - os.environ.get("USER_PERMISSIONS_FEATURES_MEMORIES", "True").lower() == "true" -) +USER_PERMISSIONS_FEATURES_MEMORIES = os.environ.get('USER_PERMISSIONS_FEATURES_MEMORIES', 'True').lower() == 'true' -USER_PERMISSIONS_SETTINGS_INTERFACE = ( - os.environ.get("USER_PERMISSIONS_SETTINGS_INTERFACE", "True").lower() == "true" -) +USER_PERMISSIONS_SETTINGS_INTERFACE = os.environ.get('USER_PERMISSIONS_SETTINGS_INTERFACE', 'True').lower() == 'true' DEFAULT_USER_PERMISSIONS = { - "workspace": { - "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, - "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, - "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, - "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, - "skills": USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS, - "models_import": USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT, - "models_export": USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT, - "prompts_import": USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT, - "prompts_export": USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT, - "tools_import": USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT, - "tools_export": USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT, + 'workspace': { + 'models': USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + 'knowledge': USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + 'prompts': USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + 'tools': USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + 'skills': USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS, + 'models_import': USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT, + 'models_export': USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT, + 'prompts_import': USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT, + 'prompts_export': USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT, + 'tools_import': USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT, + 'tools_export': USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT, }, - "sharing": { - "models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING, - "public_models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING, - "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING, - "public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING, - "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING, - "public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING, - "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING, - "public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING, - "skills": USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_SHARING, - "public_skills": USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_PUBLIC_SHARING, - "notes": USER_PERMISSIONS_NOTES_ALLOW_SHARING, - "public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING, + 'sharing': { + 'models': USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING, + 'public_models': USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING, + 'knowledge': USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING, + 'public_knowledge': USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING, + 'prompts': USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING, + 'public_prompts': USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING, + 'tools': USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING, + 'public_tools': USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING, + 'skills': USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_SHARING, + 'public_skills': USER_PERMISSIONS_WORKSPACE_SKILLS_ALLOW_PUBLIC_SHARING, + 'notes': USER_PERMISSIONS_NOTES_ALLOW_SHARING, + 'public_notes': USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING, }, - "access_grants": { - "allow_users": USER_PERMISSIONS_ACCESS_GRANTS_ALLOW_USERS, + 'access_grants': { + 'allow_users': USER_PERMISSIONS_ACCESS_GRANTS_ALLOW_USERS, }, - "chat": { - "controls": USER_PERMISSIONS_CHAT_CONTROLS, - "valves": USER_PERMISSIONS_CHAT_VALVES, - "system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT, - "params": USER_PERMISSIONS_CHAT_PARAMS, - "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, - "web_upload": USER_PERMISSIONS_CHAT_WEB_UPLOAD, - "delete": USER_PERMISSIONS_CHAT_DELETE, - "delete_message": USER_PERMISSIONS_CHAT_DELETE_MESSAGE, - "continue_response": USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE, - "regenerate_response": USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE, - "rate_response": USER_PERMISSIONS_CHAT_RATE_RESPONSE, - "edit": USER_PERMISSIONS_CHAT_EDIT, - "share": USER_PERMISSIONS_CHAT_SHARE, - "export": USER_PERMISSIONS_CHAT_EXPORT, - "stt": USER_PERMISSIONS_CHAT_STT, - "tts": USER_PERMISSIONS_CHAT_TTS, - "call": USER_PERMISSIONS_CHAT_CALL, - "multiple_models": USER_PERMISSIONS_CHAT_MULTIPLE_MODELS, - "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - "temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED, + 'chat': { + 'controls': USER_PERMISSIONS_CHAT_CONTROLS, + 'valves': USER_PERMISSIONS_CHAT_VALVES, + 'system_prompt': USER_PERMISSIONS_CHAT_SYSTEM_PROMPT, + 'params': USER_PERMISSIONS_CHAT_PARAMS, + 'file_upload': USER_PERMISSIONS_CHAT_FILE_UPLOAD, + 'web_upload': USER_PERMISSIONS_CHAT_WEB_UPLOAD, + 'delete': USER_PERMISSIONS_CHAT_DELETE, + 'delete_message': USER_PERMISSIONS_CHAT_DELETE_MESSAGE, + 'continue_response': USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE, + 'regenerate_response': USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE, + 'rate_response': USER_PERMISSIONS_CHAT_RATE_RESPONSE, + 'edit': USER_PERMISSIONS_CHAT_EDIT, + 'share': USER_PERMISSIONS_CHAT_SHARE, + 'export': USER_PERMISSIONS_CHAT_EXPORT, + 'stt': USER_PERMISSIONS_CHAT_STT, + 'tts': USER_PERMISSIONS_CHAT_TTS, + 'call': USER_PERMISSIONS_CHAT_CALL, + 'multiple_models': USER_PERMISSIONS_CHAT_MULTIPLE_MODELS, + 'temporary': USER_PERMISSIONS_CHAT_TEMPORARY, + 'temporary_enforced': USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED, }, - "features": { + 'features': { # General features - "api_keys": USER_PERMISSIONS_FEATURES_API_KEYS, - "notes": USER_PERMISSIONS_FEATURES_NOTES, - "folders": USER_PERMISSIONS_FEATURES_FOLDERS, - "channels": USER_PERMISSIONS_FEATURES_CHANNELS, - "direct_tool_servers": USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS, + 'api_keys': USER_PERMISSIONS_FEATURES_API_KEYS, + 'notes': USER_PERMISSIONS_FEATURES_NOTES, + 'folders': USER_PERMISSIONS_FEATURES_FOLDERS, + 'channels': USER_PERMISSIONS_FEATURES_CHANNELS, + 'direct_tool_servers': USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS, # Chat features - "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH, - "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION, - "code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER, - "memories": USER_PERMISSIONS_FEATURES_MEMORIES, + 'web_search': USER_PERMISSIONS_FEATURES_WEB_SEARCH, + 'image_generation': USER_PERMISSIONS_FEATURES_IMAGE_GENERATION, + 'code_interpreter': USER_PERMISSIONS_FEATURES_CODE_INTERPRETER, + 'memories': USER_PERMISSIONS_FEATURES_MEMORIES, }, - "settings": { - "interface": USER_PERMISSIONS_SETTINGS_INTERFACE, + 'settings': { + 'interface': USER_PERMISSIONS_SETTINGS_INTERFACE, }, } USER_PERMISSIONS = PersistentConfig( - "USER_PERMISSIONS", - "user.permissions", + 'USER_PERMISSIONS', + 'user.permissions', DEFAULT_USER_PERMISSIONS, ) ENABLE_FOLDERS = PersistentConfig( - "ENABLE_FOLDERS", - "folders.enable", - os.environ.get("ENABLE_FOLDERS", "True").lower() == "true", + 'ENABLE_FOLDERS', + 'folders.enable', + os.environ.get('ENABLE_FOLDERS', 'True').lower() == 'true', ) FOLDER_MAX_FILE_COUNT = PersistentConfig( - "FOLDER_MAX_FILE_COUNT", - "folders.max_file_count", - os.environ.get("FOLDER_MAX_FILE_COUNT", ""), + 'FOLDER_MAX_FILE_COUNT', + 'folders.max_file_count', + os.environ.get('FOLDER_MAX_FILE_COUNT', ''), ) ENABLE_CHANNELS = PersistentConfig( - "ENABLE_CHANNELS", - "channels.enable", - os.environ.get("ENABLE_CHANNELS", "False").lower() == "true", + 'ENABLE_CHANNELS', + 'channels.enable', + os.environ.get('ENABLE_CHANNELS', 'False').lower() == 'true', ) ENABLE_NOTES = PersistentConfig( - "ENABLE_NOTES", - "notes.enable", - os.environ.get("ENABLE_NOTES", "True").lower() == "true", + 'ENABLE_NOTES', + 'notes.enable', + os.environ.get('ENABLE_NOTES', 'True').lower() == 'true', ) ENABLE_USER_STATUS = PersistentConfig( - "ENABLE_USER_STATUS", - "users.enable_status", - os.environ.get("ENABLE_USER_STATUS", "True").lower() == "true", + 'ENABLE_USER_STATUS', + 'users.enable_status', + os.environ.get('ENABLE_USER_STATUS', 'True').lower() == 'true', ) ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig( - "ENABLE_EVALUATION_ARENA_MODELS", - "evaluation.arena.enable", - os.environ.get("ENABLE_EVALUATION_ARENA_MODELS", "True").lower() == "true", + 'ENABLE_EVALUATION_ARENA_MODELS', + 'evaluation.arena.enable', + os.environ.get('ENABLE_EVALUATION_ARENA_MODELS', 'True').lower() == 'true', ) EVALUATION_ARENA_MODELS = PersistentConfig( - "EVALUATION_ARENA_MODELS", - "evaluation.arena.models", + 'EVALUATION_ARENA_MODELS', + 'evaluation.arena.models', [], ) DEFAULT_ARENA_MODEL = { - "id": "arena-model", - "name": "Arena Model", - "meta": { - "profile_image_url": "/favicon.png", - "description": "Submit your questions to anonymous AI chatbots and vote on the best response.", - "model_ids": None, + 'id': 'arena-model', + 'name': 'Arena Model', + 'meta': { + 'profile_image_url': '/favicon.png', + 'description': 'Submit your questions to anonymous AI chatbots and vote on the best response.', + 'model_ids': None, }, } -WEBHOOK_URL = PersistentConfig( - "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") -) +WEBHOOK_URL = PersistentConfig('WEBHOOK_URL', 'webhook_url', os.environ.get('WEBHOOK_URL', '')) -ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" +ENABLE_ADMIN_EXPORT = os.environ.get('ENABLE_ADMIN_EXPORT', 'True').lower() == 'true' ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = ( - os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True").lower() == "true" + os.environ.get('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true' ) BYPASS_ADMIN_ACCESS_CONTROL = ( os.environ.get( - "BYPASS_ADMIN_ACCESS_CONTROL", - os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True"), + 'BYPASS_ADMIN_ACCESS_CONTROL', + os.environ.get('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True'), ).lower() - == "true" + == 'true' ) -ENABLE_ADMIN_CHAT_ACCESS = ( - os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true" -) +ENABLE_ADMIN_CHAT_ACCESS = os.environ.get('ENABLE_ADMIN_CHAT_ACCESS', 'True').lower() == 'true' -ENABLE_ADMIN_ANALYTICS = ( - os.environ.get("ENABLE_ADMIN_ANALYTICS", "True").lower() == "true" -) +ENABLE_ADMIN_ANALYTICS = os.environ.get('ENABLE_ADMIN_ANALYTICS', 'True').lower() == 'true' ENABLE_COMMUNITY_SHARING = PersistentConfig( - "ENABLE_COMMUNITY_SHARING", - "ui.enable_community_sharing", - os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true", + 'ENABLE_COMMUNITY_SHARING', + 'ui.enable_community_sharing', + os.environ.get('ENABLE_COMMUNITY_SHARING', 'True').lower() == 'true', ) ENABLE_MESSAGE_RATING = PersistentConfig( - "ENABLE_MESSAGE_RATING", - "ui.enable_message_rating", - os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true", + 'ENABLE_MESSAGE_RATING', + 'ui.enable_message_rating', + os.environ.get('ENABLE_MESSAGE_RATING', 'True').lower() == 'true', ) ENABLE_USER_WEBHOOKS = PersistentConfig( - "ENABLE_USER_WEBHOOKS", - "ui.enable_user_webhooks", - os.environ.get("ENABLE_USER_WEBHOOKS", "True").lower() == "true", + 'ENABLE_USER_WEBHOOKS', + 'ui.enable_user_webhooks', + os.environ.get('ENABLE_USER_WEBHOOKS', 'True').lower() == 'true', ) # FastAPI / AnyIO settings -THREAD_POOL_SIZE = os.getenv("THREAD_POOL_SIZE", None) +THREAD_POOL_SIZE = os.getenv('THREAD_POOL_SIZE', None) if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str): try: THREAD_POOL_SIZE = int(THREAD_POOL_SIZE) except ValueError: - log.warning( - f"THREAD_POOL_SIZE is not a valid integer: {THREAD_POOL_SIZE}. Defaulting to None." - ) + log.warning(f'THREAD_POOL_SIZE is not a valid integer: {THREAD_POOL_SIZE}. Defaulting to None.') THREAD_POOL_SIZE = None @@ -1799,7 +1595,7 @@ def validate_cors_origin(origin): parsed_url = urlparse(origin) # Check if the scheme is either http or https, or a custom scheme - schemes = ["http", "https"] + CORS_ALLOW_CUSTOM_SCHEME + schemes = ['http', 'https'] + CORS_ALLOW_CUSTOM_SCHEME if parsed_url.scheme not in schemes: raise ValueError( f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' and CORS_ALLOW_CUSTOM_SCHEME are allowed." @@ -1815,17 +1611,15 @@ def validate_cors_origin(origin): # To test CORS_ALLOW_ORIGIN locally, you can set something like # CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080 # in your .env file depending on your frontend port, 5173 in this case. -CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";") +CORS_ALLOW_ORIGIN = os.environ.get('CORS_ALLOW_ORIGIN', '*').split(';') # Allows custom URL schemes (e.g., app://) to be used as origins for CORS. # Useful for local development or desktop clients with schemes like app:// or other custom protocols. # Provide a semicolon-separated list of allowed schemes in the environment variable CORS_ALLOW_CUSTOM_SCHEMES. -CORS_ALLOW_CUSTOM_SCHEME = os.environ.get("CORS_ALLOW_CUSTOM_SCHEME", "").split(";") +CORS_ALLOW_CUSTOM_SCHEME = os.environ.get('CORS_ALLOW_CUSTOM_SCHEME', '').split(';') -if CORS_ALLOW_ORIGIN == ["*"]: - log.warning( - "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n" - ) +if CORS_ALLOW_ORIGIN == ['*']: + log.warning("\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n") else: # You have to pick between a single wildcard or a list of origins. # Doing both will result in CORS errors in the browser. @@ -1843,25 +1637,25 @@ class BannerModel(BaseModel): try: - banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]")) + banners = json.loads(os.environ.get('WEBUI_BANNERS', '[]')) banners = [BannerModel(**banner) for banner in banners] except Exception as e: - log.exception(f"Error loading WEBUI_BANNERS: {e}") + log.exception(f'Error loading WEBUI_BANNERS: {e}') banners = [] -WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners) +WEBUI_BANNERS = PersistentConfig('WEBUI_BANNERS', 'ui.banners', banners) SHOW_ADMIN_DETAILS = PersistentConfig( - "SHOW_ADMIN_DETAILS", - "auth.admin.show", - os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true", + 'SHOW_ADMIN_DETAILS', + 'auth.admin.show', + os.environ.get('SHOW_ADMIN_DETAILS', 'true').lower() == 'true', ) ADMIN_EMAIL = PersistentConfig( - "ADMIN_EMAIL", - "auth.admin.email", - os.environ.get("ADMIN_EMAIL", None), + 'ADMIN_EMAIL', + 'auth.admin.email', + os.environ.get('ADMIN_EMAIL', None), ) @@ -1871,21 +1665,21 @@ ADMIN_EMAIL = PersistentConfig( TASK_MODEL = PersistentConfig( - "TASK_MODEL", - "task.model.default", - os.environ.get("TASK_MODEL", ""), + 'TASK_MODEL', + 'task.model.default', + os.environ.get('TASK_MODEL', ''), ) TASK_MODEL_EXTERNAL = PersistentConfig( - "TASK_MODEL_EXTERNAL", - "task.model.external", - os.environ.get("TASK_MODEL_EXTERNAL", ""), + 'TASK_MODEL_EXTERNAL', + 'task.model.external', + os.environ.get('TASK_MODEL_EXTERNAL', ''), ) TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "TITLE_GENERATION_PROMPT_TEMPLATE", - "task.title.prompt_template", - os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""), + 'TITLE_GENERATION_PROMPT_TEMPLATE', + 'task.title.prompt_template', + os.environ.get('TITLE_GENERATION_PROMPT_TEMPLATE', ''), ) DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """### Task: @@ -1913,9 +1707,9 @@ JSON format: { "title": "your concise title here" } """ TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "TAGS_GENERATION_PROMPT_TEMPLATE", - "task.tags.prompt_template", - os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), + 'TAGS_GENERATION_PROMPT_TEMPLATE', + 'task.tags.prompt_template', + os.environ.get('TAGS_GENERATION_PROMPT_TEMPLATE', ''), ) DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task: @@ -1937,9 +1731,9 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } """ IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", - "task.image.prompt_template", - os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""), + 'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE', + 'task.image.prompt_template', + os.environ.get('IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE', ''), ) DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task: @@ -1964,9 +1758,9 @@ Strictly return in JSON format: FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", - "task.follow_up.prompt_template", - os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""), + 'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE', + 'task.follow_up.prompt_template', + os.environ.get('FOLLOW_UP_GENERATION_PROMPT_TEMPLATE', ''), ) DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task: @@ -1986,41 +1780,41 @@ JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] } """ ENABLE_FOLLOW_UP_GENERATION = PersistentConfig( - "ENABLE_FOLLOW_UP_GENERATION", - "task.follow_up.enable", - os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true", + 'ENABLE_FOLLOW_UP_GENERATION', + 'task.follow_up.enable', + os.environ.get('ENABLE_FOLLOW_UP_GENERATION', 'True').lower() == 'true', ) ENABLE_TAGS_GENERATION = PersistentConfig( - "ENABLE_TAGS_GENERATION", - "task.tags.enable", - os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", + 'ENABLE_TAGS_GENERATION', + 'task.tags.enable', + os.environ.get('ENABLE_TAGS_GENERATION', 'True').lower() == 'true', ) ENABLE_TITLE_GENERATION = PersistentConfig( - "ENABLE_TITLE_GENERATION", - "task.title.enable", - os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true", + 'ENABLE_TITLE_GENERATION', + 'task.title.enable', + os.environ.get('ENABLE_TITLE_GENERATION', 'True').lower() == 'true', ) ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( - "ENABLE_SEARCH_QUERY_GENERATION", - "task.query.search.enable", - os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true", + 'ENABLE_SEARCH_QUERY_GENERATION', + 'task.query.search.enable', + os.environ.get('ENABLE_SEARCH_QUERY_GENERATION', 'True').lower() == 'true', ) ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig( - "ENABLE_RETRIEVAL_QUERY_GENERATION", - "task.query.retrieval.enable", - os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true", + 'ENABLE_RETRIEVAL_QUERY_GENERATION', + 'task.query.retrieval.enable', + os.environ.get('ENABLE_RETRIEVAL_QUERY_GENERATION', 'True').lower() == 'true', ) QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "QUERY_GENERATION_PROMPT_TEMPLATE", - "task.query.prompt_template", - os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""), + 'QUERY_GENERATION_PROMPT_TEMPLATE', + 'task.query.prompt_template', + os.environ.get('QUERY_GENERATION_PROMPT_TEMPLATE', ''), ) DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: @@ -2048,21 +1842,21 @@ Strictly return in JSON format: """ ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig( - "ENABLE_AUTOCOMPLETE_GENERATION", - "task.autocomplete.enable", - os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "False").lower() == "true", + 'ENABLE_AUTOCOMPLETE_GENERATION', + 'task.autocomplete.enable', + os.environ.get('ENABLE_AUTOCOMPLETE_GENERATION', 'False').lower() == 'true', ) AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig( - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", - "task.autocomplete.input_max_length", - int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")), + 'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH', + 'task.autocomplete.input_max_length', + int(os.environ.get('AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH', '-1')), ) AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", - "task.autocomplete.prompt_template", - os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""), + 'AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE', + 'task.autocomplete.prompt_template', + os.environ.get('AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE', ''), ) @@ -2110,9 +1904,9 @@ Output: VOICE_MODE_PROMPT_TEMPLATE = PersistentConfig( - "VOICE_MODE_PROMPT_TEMPLATE", - "task.voice.prompt_template", - os.environ.get("VOICE_MODE_PROMPT_TEMPLATE", ""), + 'VOICE_MODE_PROMPT_TEMPLATE', + 'task.voice.prompt_template', + os.environ.get('VOICE_MODE_PROMPT_TEMPLATE', ''), ) DEFAULT_VOICE_MODE_PROMPT_TEMPLATE = """You are a friendly, concise voice assistant. @@ -2141,9 +1935,9 @@ ERROR HANDLING: Stay consistent, helpful, and easy to listen to.""" TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", - "task.tools.prompt_template", - os.environ.get("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""), + 'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE', + 'task.tools.prompt_template', + os.environ.get('TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE', ''), ) @@ -2187,123 +1981,119 @@ Responses from models: {{responses}}""" #################################### ENABLE_CODE_EXECUTION = PersistentConfig( - "ENABLE_CODE_EXECUTION", - "code_execution.enable", - os.environ.get("ENABLE_CODE_EXECUTION", "True").lower() == "true", + 'ENABLE_CODE_EXECUTION', + 'code_execution.enable', + os.environ.get('ENABLE_CODE_EXECUTION', 'True').lower() == 'true', ) CODE_EXECUTION_ENGINE = PersistentConfig( - "CODE_EXECUTION_ENGINE", - "code_execution.engine", - os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"), + 'CODE_EXECUTION_ENGINE', + 'code_execution.engine', + os.environ.get('CODE_EXECUTION_ENGINE', 'pyodide'), ) CODE_EXECUTION_JUPYTER_URL = PersistentConfig( - "CODE_EXECUTION_JUPYTER_URL", - "code_execution.jupyter.url", - os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""), + 'CODE_EXECUTION_JUPYTER_URL', + 'code_execution.jupyter.url', + os.environ.get('CODE_EXECUTION_JUPYTER_URL', ''), ) CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig( - "CODE_EXECUTION_JUPYTER_AUTH", - "code_execution.jupyter.auth", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""), + 'CODE_EXECUTION_JUPYTER_AUTH', + 'code_execution.jupyter.auth', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH', ''), ) CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig( - "CODE_EXECUTION_JUPYTER_AUTH_TOKEN", - "code_execution.jupyter.auth_token", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""), + 'CODE_EXECUTION_JUPYTER_AUTH_TOKEN', + 'code_execution.jupyter.auth_token', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH_TOKEN', ''), ) CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig( - "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", - "code_execution.jupyter.auth_password", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""), + 'CODE_EXECUTION_JUPYTER_AUTH_PASSWORD', + 'code_execution.jupyter.auth_password', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH_PASSWORD', ''), ) CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig( - "CODE_EXECUTION_JUPYTER_TIMEOUT", - "code_execution.jupyter.timeout", - int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")), + 'CODE_EXECUTION_JUPYTER_TIMEOUT', + 'code_execution.jupyter.timeout', + int(os.environ.get('CODE_EXECUTION_JUPYTER_TIMEOUT', '60')), ) ENABLE_CODE_INTERPRETER = PersistentConfig( - "ENABLE_CODE_INTERPRETER", - "code_interpreter.enable", - os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true", + 'ENABLE_CODE_INTERPRETER', + 'code_interpreter.enable', + os.environ.get('ENABLE_CODE_INTERPRETER', 'True').lower() == 'true', ) ENABLE_MEMORIES = PersistentConfig( - "ENABLE_MEMORIES", - "memories.enable", - os.environ.get("ENABLE_MEMORIES", "True").lower() == "true", + 'ENABLE_MEMORIES', + 'memories.enable', + os.environ.get('ENABLE_MEMORIES', 'True').lower() == 'true', ) CODE_INTERPRETER_ENGINE = PersistentConfig( - "CODE_INTERPRETER_ENGINE", - "code_interpreter.engine", - os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"), + 'CODE_INTERPRETER_ENGINE', + 'code_interpreter.engine', + os.environ.get('CODE_INTERPRETER_ENGINE', 'pyodide'), ) CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig( - "CODE_INTERPRETER_PROMPT_TEMPLATE", - "code_interpreter.prompt_template", - os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""), + 'CODE_INTERPRETER_PROMPT_TEMPLATE', + 'code_interpreter.prompt_template', + os.environ.get('CODE_INTERPRETER_PROMPT_TEMPLATE', ''), ) CODE_INTERPRETER_JUPYTER_URL = PersistentConfig( - "CODE_INTERPRETER_JUPYTER_URL", - "code_interpreter.jupyter.url", - os.environ.get( - "CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "") - ), + 'CODE_INTERPRETER_JUPYTER_URL', + 'code_interpreter.jupyter.url', + os.environ.get('CODE_INTERPRETER_JUPYTER_URL', os.environ.get('CODE_EXECUTION_JUPYTER_URL', '')), ) CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig( - "CODE_INTERPRETER_JUPYTER_AUTH", - "code_interpreter.jupyter.auth", + 'CODE_INTERPRETER_JUPYTER_AUTH', + 'code_interpreter.jupyter.auth', os.environ.get( - "CODE_INTERPRETER_JUPYTER_AUTH", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""), + 'CODE_INTERPRETER_JUPYTER_AUTH', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH', ''), ), ) CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig( - "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", - "code_interpreter.jupyter.auth_token", + 'CODE_INTERPRETER_JUPYTER_AUTH_TOKEN', + 'code_interpreter.jupyter.auth_token', os.environ.get( - "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""), + 'CODE_INTERPRETER_JUPYTER_AUTH_TOKEN', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH_TOKEN', ''), ), ) CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig( - "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", - "code_interpreter.jupyter.auth_password", + 'CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD', + 'code_interpreter.jupyter.auth_password', os.environ.get( - "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", - os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""), + 'CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD', + os.environ.get('CODE_EXECUTION_JUPYTER_AUTH_PASSWORD', ''), ), ) CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig( - "CODE_INTERPRETER_JUPYTER_TIMEOUT", - "code_interpreter.jupyter.timeout", + 'CODE_INTERPRETER_JUPYTER_TIMEOUT', + 'code_interpreter.jupyter.timeout', int( os.environ.get( - "CODE_INTERPRETER_JUPYTER_TIMEOUT", - os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"), + 'CODE_INTERPRETER_JUPYTER_TIMEOUT', + os.environ.get('CODE_EXECUTION_JUPYTER_TIMEOUT', '60'), ) ), ) CODE_INTERPRETER_BLOCKED_MODULES = [ - library.strip() - for library in os.environ.get("CODE_INTERPRETER_BLOCKED_MODULES", "").split(",") - if library.strip() + library.strip() for library in os.environ.get('CODE_INTERPRETER_BLOCKED_MODULES', '').split(',') if library.strip() ] DEFAULT_CODE_INTERPRETER_PROMPT = """ @@ -2343,56 +2133,47 @@ CODE_INTERPRETER_PYODIDE_PROMPT = """ # Vector Database #################################### -VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") +VECTOR_DB = os.environ.get('VECTOR_DB', 'chroma') # Chroma -CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" +CHROMA_DATA_PATH = f'{DATA_DIR}/vector_db' -if VECTOR_DB == "chroma": +if VECTOR_DB == 'chroma': import chromadb - CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) - CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) - CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "") - CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000")) - CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "") - CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get( - "CHROMA_CLIENT_AUTH_CREDENTIALS", "" - ) + CHROMA_TENANT = os.environ.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT) + CHROMA_DATABASE = os.environ.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE) + CHROMA_HTTP_HOST = os.environ.get('CHROMA_HTTP_HOST', '') + CHROMA_HTTP_PORT = int(os.environ.get('CHROMA_HTTP_PORT', '8000')) + CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get('CHROMA_CLIENT_AUTH_PROVIDER', '') + CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get('CHROMA_CLIENT_AUTH_CREDENTIALS', '') # Comma-separated list of header=value pairs - CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "") + CHROMA_HTTP_HEADERS = os.environ.get('CHROMA_HTTP_HEADERS', '') if CHROMA_HTTP_HEADERS: - CHROMA_HTTP_HEADERS = dict( - [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")] - ) + CHROMA_HTTP_HEADERS = dict([pair.split('=') for pair in CHROMA_HTTP_HEADERS.split(',')]) else: CHROMA_HTTP_HEADERS = None - CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" + CHROMA_HTTP_SSL = os.environ.get('CHROMA_HTTP_SSL', 'false').lower() == 'true' # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # MariaDB Vector (mariadb-vector) -MARIADB_VECTOR_DB_URL = os.environ.get("MARIADB_VECTOR_DB_URL", "").strip() +MARIADB_VECTOR_DB_URL = os.environ.get('MARIADB_VECTOR_DB_URL', '').strip() MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int( - os.environ.get("MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536").strip() - or "1536" + os.environ.get('MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH', '1536').strip() or '1536' ) # Distance strategy: # - cosine => vec_distance_cosine(...) # - euclidean => vec_distance_euclidean(...) -MARIADB_VECTOR_DISTANCE_STRATEGY = ( - os.environ.get("MARIADB_VECTOR_DISTANCE_STRATEGY", "cosine").strip().lower() -) +MARIADB_VECTOR_DISTANCE_STRATEGY = os.environ.get('MARIADB_VECTOR_DISTANCE_STRATEGY', 'cosine').strip().lower() # HNSW M parameter (MariaDB VECTOR INDEX ... M=) -MARIADB_VECTOR_INDEX_M = int( - os.environ.get("MARIADB_VECTOR_INDEX_M", "8").strip() or "8" -) +MARIADB_VECTOR_INDEX_M = int(os.environ.get('MARIADB_VECTOR_INDEX_M', '8').strip() or '8') # Pooling (MariaDB-Vector) -MARIADB_VECTOR_POOL_SIZE = os.environ.get("MARIADB_VECTOR_POOL_SIZE", None) +MARIADB_VECTOR_POOL_SIZE = os.environ.get('MARIADB_VECTOR_POOL_SIZE', None) if MARIADB_VECTOR_POOL_SIZE != None: try: @@ -2400,9 +2181,9 @@ if MARIADB_VECTOR_POOL_SIZE != None: except Exception: MARIADB_VECTOR_POOL_SIZE = None -MARIADB_VECTOR_POOL_MAX_OVERFLOW = os.environ.get("MARIADB_VECTOR_POOL_MAX_OVERFLOW", 0) +MARIADB_VECTOR_POOL_MAX_OVERFLOW = os.environ.get('MARIADB_VECTOR_POOL_MAX_OVERFLOW', 0) -if MARIADB_VECTOR_POOL_MAX_OVERFLOW == "": +if MARIADB_VECTOR_POOL_MAX_OVERFLOW == '': MARIADB_VECTOR_POOL_MAX_OVERFLOW = 0 else: try: @@ -2410,9 +2191,9 @@ else: except Exception: MARIADB_VECTOR_POOL_MAX_OVERFLOW = 0 -MARIADB_VECTOR_POOL_TIMEOUT = os.environ.get("MARIADB_VECTOR_POOL_TIMEOUT", 30) +MARIADB_VECTOR_POOL_TIMEOUT = os.environ.get('MARIADB_VECTOR_POOL_TIMEOUT', 30) -if MARIADB_VECTOR_POOL_TIMEOUT == "": +if MARIADB_VECTOR_POOL_TIMEOUT == '': MARIADB_VECTOR_POOL_TIMEOUT = 30 else: try: @@ -2420,9 +2201,9 @@ else: except Exception: MARIADB_VECTOR_POOL_TIMEOUT = 30 -MARIADB_VECTOR_POOL_RECYCLE = os.environ.get("MARIADB_VECTOR_POOL_RECYCLE", 3600) +MARIADB_VECTOR_POOL_RECYCLE = os.environ.get('MARIADB_VECTOR_POOL_RECYCLE', 3600) -if MARIADB_VECTOR_POOL_RECYCLE == "": +if MARIADB_VECTOR_POOL_RECYCLE == '': MARIADB_VECTOR_POOL_RECYCLE = 3600 else: try: @@ -2431,115 +2212,97 @@ else: MARIADB_VECTOR_POOL_RECYCLE = 3600 ENABLE_MARIADB_VECTOR = True -if VECTOR_DB == "mariadb-vector": +if VECTOR_DB == 'mariadb-vector': if not MARIADB_VECTOR_DB_URL: ENABLE_MARIADB_VECTOR = False else: try: parsed = urlparse(MARIADB_VECTOR_DB_URL) - scheme = (parsed.scheme or "").lower() + scheme = (parsed.scheme or '').lower() # Require official driver so VECTOR binds as float32 bytes correctly - if scheme != "mariadb+mariadbconnector": + if scheme != 'mariadb+mariadbconnector': ENABLE_MARIADB_VECTOR = False except Exception: ENABLE_MARIADB_VECTOR = False # Milvus -MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") -MILVUS_DB = os.environ.get("MILVUS_DB", "default") -MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None) -MILVUS_INDEX_TYPE = os.environ.get("MILVUS_INDEX_TYPE", "HNSW") -MILVUS_METRIC_TYPE = os.environ.get("MILVUS_METRIC_TYPE", "COSINE") -MILVUS_HNSW_M = int(os.environ.get("MILVUS_HNSW_M", "16")) -MILVUS_HNSW_EFCONSTRUCTION = int(os.environ.get("MILVUS_HNSW_EFCONSTRUCTION", "100")) -MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128")) -MILVUS_DISKANN_MAX_DEGREE = int(os.environ.get("MILVUS_DISKANN_MAX_DEGREE", "56")) -MILVUS_DISKANN_SEARCH_LIST_SIZE = int( - os.environ.get("MILVUS_DISKANN_SEARCH_LIST_SIZE", "100") -) -ENABLE_MILVUS_MULTITENANCY_MODE = ( - os.environ.get("ENABLE_MILVUS_MULTITENANCY_MODE", "false").lower() == "true" -) +MILVUS_URI = os.environ.get('MILVUS_URI', f'{DATA_DIR}/vector_db/milvus.db') +MILVUS_DB = os.environ.get('MILVUS_DB', 'default') +MILVUS_TOKEN = os.environ.get('MILVUS_TOKEN', None) +MILVUS_INDEX_TYPE = os.environ.get('MILVUS_INDEX_TYPE', 'HNSW') +MILVUS_METRIC_TYPE = os.environ.get('MILVUS_METRIC_TYPE', 'COSINE') +MILVUS_HNSW_M = int(os.environ.get('MILVUS_HNSW_M', '16')) +MILVUS_HNSW_EFCONSTRUCTION = int(os.environ.get('MILVUS_HNSW_EFCONSTRUCTION', '100')) +MILVUS_IVF_FLAT_NLIST = int(os.environ.get('MILVUS_IVF_FLAT_NLIST', '128')) +MILVUS_DISKANN_MAX_DEGREE = int(os.environ.get('MILVUS_DISKANN_MAX_DEGREE', '56')) +MILVUS_DISKANN_SEARCH_LIST_SIZE = int(os.environ.get('MILVUS_DISKANN_SEARCH_LIST_SIZE', '100')) +ENABLE_MILVUS_MULTITENANCY_MODE = os.environ.get('ENABLE_MILVUS_MULTITENANCY_MODE', 'false').lower() == 'true' # Hyphens not allowed, need to use underscores in collection names -MILVUS_COLLECTION_PREFIX = os.environ.get("MILVUS_COLLECTION_PREFIX", "open_webui") +MILVUS_COLLECTION_PREFIX = os.environ.get('MILVUS_COLLECTION_PREFIX', 'open_webui') # Qdrant -QDRANT_URI = os.environ.get("QDRANT_URI", None) -QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) -QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true" -QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true" -QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334")) -QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5")) -QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16")) -ENABLE_QDRANT_MULTITENANCY_MODE = ( - os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true" -) -QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui") +QDRANT_URI = os.environ.get('QDRANT_URI', None) +QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY', None) +QDRANT_ON_DISK = os.environ.get('QDRANT_ON_DISK', 'false').lower() == 'true' +QDRANT_PREFER_GRPC = os.environ.get('QDRANT_PREFER_GRPC', 'false').lower() == 'true' +QDRANT_GRPC_PORT = int(os.environ.get('QDRANT_GRPC_PORT', '6334')) +QDRANT_TIMEOUT = int(os.environ.get('QDRANT_TIMEOUT', '5')) +QDRANT_HNSW_M = int(os.environ.get('QDRANT_HNSW_M', '16')) +ENABLE_QDRANT_MULTITENANCY_MODE = os.environ.get('ENABLE_QDRANT_MULTITENANCY_MODE', 'true').lower() == 'true' +QDRANT_COLLECTION_PREFIX = os.environ.get('QDRANT_COLLECTION_PREFIX', 'open-webui') -WEAVIATE_HTTP_HOST = os.environ.get("WEAVIATE_HTTP_HOST", "") -WEAVIATE_GRPC_HOST = os.environ.get("WEAVIATE_GRPC_HOST", "") -WEAVIATE_HTTP_PORT = int(os.environ.get("WEAVIATE_HTTP_PORT", "8080")) -WEAVIATE_GRPC_PORT = int(os.environ.get("WEAVIATE_GRPC_PORT", "50051")) -WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY") -WEAVIATE_HTTP_SECURE = os.environ.get("WEAVIATE_HTTP_SECURE", "false").lower() == "true" -WEAVIATE_GRPC_SECURE = os.environ.get("WEAVIATE_GRPC_SECURE", "false").lower() == "true" -WEAVIATE_SKIP_INIT_CHECKS = ( - os.environ.get("WEAVIATE_SKIP_INIT_CHECKS", "false").lower() == "true" -) +WEAVIATE_HTTP_HOST = os.environ.get('WEAVIATE_HTTP_HOST', '') +WEAVIATE_GRPC_HOST = os.environ.get('WEAVIATE_GRPC_HOST', '') +WEAVIATE_HTTP_PORT = int(os.environ.get('WEAVIATE_HTTP_PORT', '8080')) +WEAVIATE_GRPC_PORT = int(os.environ.get('WEAVIATE_GRPC_PORT', '50051')) +WEAVIATE_API_KEY = os.environ.get('WEAVIATE_API_KEY') +WEAVIATE_HTTP_SECURE = os.environ.get('WEAVIATE_HTTP_SECURE', 'false').lower() == 'true' +WEAVIATE_GRPC_SECURE = os.environ.get('WEAVIATE_GRPC_SECURE', 'false').lower() == 'true' +WEAVIATE_SKIP_INIT_CHECKS = os.environ.get('WEAVIATE_SKIP_INIT_CHECKS', 'false').lower() == 'true' # OpenSearch -OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") -OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true" -OPENSEARCH_CERT_VERIFY = ( - os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true" -) -OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None) -OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None) +OPENSEARCH_URI = os.environ.get('OPENSEARCH_URI', 'https://localhost:9200') +OPENSEARCH_SSL = os.environ.get('OPENSEARCH_SSL', 'true').lower() == 'true' +OPENSEARCH_CERT_VERIFY = os.environ.get('OPENSEARCH_CERT_VERIFY', 'false').lower() == 'true' +OPENSEARCH_USERNAME = os.environ.get('OPENSEARCH_USERNAME', None) +OPENSEARCH_PASSWORD = os.environ.get('OPENSEARCH_PASSWORD', None) # ElasticSearch -ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200") -ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None) -ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None) -ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None) -ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None) -ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None) -SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None) -ELASTICSEARCH_INDEX_PREFIX = os.environ.get( - "ELASTICSEARCH_INDEX_PREFIX", "open_webui_collections" -) +ELASTICSEARCH_URL = os.environ.get('ELASTICSEARCH_URL', 'https://localhost:9200') +ELASTICSEARCH_CA_CERTS = os.environ.get('ELASTICSEARCH_CA_CERTS', None) +ELASTICSEARCH_API_KEY = os.environ.get('ELASTICSEARCH_API_KEY', None) +ELASTICSEARCH_USERNAME = os.environ.get('ELASTICSEARCH_USERNAME', None) +ELASTICSEARCH_PASSWORD = os.environ.get('ELASTICSEARCH_PASSWORD', None) +ELASTICSEARCH_CLOUD_ID = os.environ.get('ELASTICSEARCH_CLOUD_ID', None) +SSL_ASSERT_FINGERPRINT = os.environ.get('SSL_ASSERT_FINGERPRINT', None) +ELASTICSEARCH_INDEX_PREFIX = os.environ.get('ELASTICSEARCH_INDEX_PREFIX', 'open_webui_collections') # Pgvector -PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL) -if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"): +PGVECTOR_DB_URL = os.environ.get('PGVECTOR_DB_URL', DATABASE_URL) +if VECTOR_DB == 'pgvector' and not PGVECTOR_DB_URL.startswith('postgres'): raise ValueError( - "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database." + 'Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database.' ) -PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int( - os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536") -) +PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(os.environ.get('PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH', '1536')) -PGVECTOR_USE_HALFVEC = os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true" +PGVECTOR_USE_HALFVEC = os.getenv('PGVECTOR_USE_HALFVEC', 'false').lower() == 'true' if PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000 and not PGVECTOR_USE_HALFVEC: raise ValueError( - "PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to " - f"{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit of the " + 'PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to ' + f'{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit of the ' "'vector' type. Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' " - "type required for high-dimensional embeddings." + 'type required for high-dimensional embeddings.' ) -PGVECTOR_CREATE_EXTENSION = ( - os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true" -) -PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true" -PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None) +PGVECTOR_CREATE_EXTENSION = os.getenv('PGVECTOR_CREATE_EXTENSION', 'true').lower() == 'true' +PGVECTOR_PGCRYPTO = os.getenv('PGVECTOR_PGCRYPTO', 'false').lower() == 'true' +PGVECTOR_PGCRYPTO_KEY = os.getenv('PGVECTOR_PGCRYPTO_KEY', None) if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY: - raise ValueError( - "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key." - ) + raise ValueError('PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key.') -PGVECTOR_POOL_SIZE = os.environ.get("PGVECTOR_POOL_SIZE", None) +PGVECTOR_POOL_SIZE = os.environ.get('PGVECTOR_POOL_SIZE', None) if PGVECTOR_POOL_SIZE != None: try: @@ -2547,9 +2310,9 @@ if PGVECTOR_POOL_SIZE != None: except Exception: PGVECTOR_POOL_SIZE = None -PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get("PGVECTOR_POOL_MAX_OVERFLOW", 0) +PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get('PGVECTOR_POOL_MAX_OVERFLOW', 0) -if PGVECTOR_POOL_MAX_OVERFLOW == "": +if PGVECTOR_POOL_MAX_OVERFLOW == '': PGVECTOR_POOL_MAX_OVERFLOW = 0 else: try: @@ -2557,9 +2320,9 @@ else: except Exception: PGVECTOR_POOL_MAX_OVERFLOW = 0 -PGVECTOR_POOL_TIMEOUT = os.environ.get("PGVECTOR_POOL_TIMEOUT", 30) +PGVECTOR_POOL_TIMEOUT = os.environ.get('PGVECTOR_POOL_TIMEOUT', 30) -if PGVECTOR_POOL_TIMEOUT == "": +if PGVECTOR_POOL_TIMEOUT == '': PGVECTOR_POOL_TIMEOUT = 30 else: try: @@ -2567,9 +2330,9 @@ else: except Exception: PGVECTOR_POOL_TIMEOUT = 30 -PGVECTOR_POOL_RECYCLE = os.environ.get("PGVECTOR_POOL_RECYCLE", 3600) +PGVECTOR_POOL_RECYCLE = os.environ.get('PGVECTOR_POOL_RECYCLE', 3600) -if PGVECTOR_POOL_RECYCLE == "": +if PGVECTOR_POOL_RECYCLE == '': PGVECTOR_POOL_RECYCLE = 3600 else: try: @@ -2577,13 +2340,13 @@ else: except Exception: PGVECTOR_POOL_RECYCLE = 3600 -PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower() -if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw", ""): - PGVECTOR_INDEX_METHOD = "" +PGVECTOR_INDEX_METHOD = os.getenv('PGVECTOR_INDEX_METHOD', '').strip().lower() +if PGVECTOR_INDEX_METHOD not in ('ivfflat', 'hnsw', ''): + PGVECTOR_INDEX_METHOD = '' -PGVECTOR_HNSW_M = os.environ.get("PGVECTOR_HNSW_M", 16) +PGVECTOR_HNSW_M = os.environ.get('PGVECTOR_HNSW_M', 16) -if PGVECTOR_HNSW_M == "": +if PGVECTOR_HNSW_M == '': PGVECTOR_HNSW_M = 16 else: try: @@ -2591,9 +2354,9 @@ else: except Exception: PGVECTOR_HNSW_M = 16 -PGVECTOR_HNSW_EF_CONSTRUCTION = os.environ.get("PGVECTOR_HNSW_EF_CONSTRUCTION", 64) +PGVECTOR_HNSW_EF_CONSTRUCTION = os.environ.get('PGVECTOR_HNSW_EF_CONSTRUCTION', 64) -if PGVECTOR_HNSW_EF_CONSTRUCTION == "": +if PGVECTOR_HNSW_EF_CONSTRUCTION == '': PGVECTOR_HNSW_EF_CONSTRUCTION = 64 else: try: @@ -2601,9 +2364,9 @@ else: except Exception: PGVECTOR_HNSW_EF_CONSTRUCTION = 64 -PGVECTOR_IVFFLAT_LISTS = os.environ.get("PGVECTOR_IVFFLAT_LISTS", 100) +PGVECTOR_IVFFLAT_LISTS = os.environ.get('PGVECTOR_IVFFLAT_LISTS', 100) -if PGVECTOR_IVFFLAT_LISTS == "": +if PGVECTOR_IVFFLAT_LISTS == '': PGVECTOR_IVFFLAT_LISTS = 100 else: try: @@ -2612,13 +2375,11 @@ else: PGVECTOR_IVFFLAT_LISTS = 100 # openGauss -OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL) +OPENGAUSS_DB_URL = os.environ.get('OPENGAUSS_DB_URL', DATABASE_URL) -OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int( - os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536") -) +OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int(os.environ.get('OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH', '1536')) -OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None) +OPENGAUSS_POOL_SIZE = os.environ.get('OPENGAUSS_POOL_SIZE', None) if OPENGAUSS_POOL_SIZE != None: try: @@ -2626,9 +2387,9 @@ if OPENGAUSS_POOL_SIZE != None: except Exception: OPENGAUSS_POOL_SIZE = None -OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0) +OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get('OPENGAUSS_POOL_MAX_OVERFLOW', 0) -if OPENGAUSS_POOL_MAX_OVERFLOW == "": +if OPENGAUSS_POOL_MAX_OVERFLOW == '': OPENGAUSS_POOL_MAX_OVERFLOW = 0 else: try: @@ -2636,9 +2397,9 @@ else: except Exception: OPENGAUSS_POOL_MAX_OVERFLOW = 0 -OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30) +OPENGAUSS_POOL_TIMEOUT = os.environ.get('OPENGAUSS_POOL_TIMEOUT', 30) -if OPENGAUSS_POOL_TIMEOUT == "": +if OPENGAUSS_POOL_TIMEOUT == '': OPENGAUSS_POOL_TIMEOUT = 30 else: try: @@ -2646,9 +2407,9 @@ else: except Exception: OPENGAUSS_POOL_TIMEOUT = 30 -OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600) +OPENGAUSS_POOL_RECYCLE = os.environ.get('OPENGAUSS_POOL_RECYCLE', 3600) -if OPENGAUSS_POOL_RECYCLE == "": +if OPENGAUSS_POOL_RECYCLE == '': OPENGAUSS_POOL_RECYCLE = 3600 else: try: @@ -2657,43 +2418,41 @@ else: OPENGAUSS_POOL_RECYCLE = 3600 # Pinecone -PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) -PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) -PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "open-webui-index") -PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536)) # or 3072, 1024, 768 -PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine") -PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws") # or "gcp" or "azure" +PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', None) +PINECONE_ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT', None) +PINECONE_INDEX_NAME = os.getenv('PINECONE_INDEX_NAME', 'open-webui-index') +PINECONE_DIMENSION = int(os.getenv('PINECONE_DIMENSION', 1536)) # or 3072, 1024, 768 +PINECONE_METRIC = os.getenv('PINECONE_METRIC', 'cosine') +PINECONE_CLOUD = os.getenv('PINECONE_CLOUD', 'aws') # or "gcp" or "azure" # ORACLE23AI (Oracle23ai Vector Search) -ORACLE_DB_USE_WALLET = os.environ.get("ORACLE_DB_USE_WALLET", "false").lower() == "true" -ORACLE_DB_USER = os.environ.get("ORACLE_DB_USER", None) # -ORACLE_DB_PASSWORD = os.environ.get("ORACLE_DB_PASSWORD", None) # -ORACLE_DB_DSN = os.environ.get("ORACLE_DB_DSN", None) # -ORACLE_WALLET_DIR = os.environ.get("ORACLE_WALLET_DIR", None) -ORACLE_WALLET_PASSWORD = os.environ.get("ORACLE_WALLET_PASSWORD", None) -ORACLE_VECTOR_LENGTH = os.environ.get("ORACLE_VECTOR_LENGTH", 768) +ORACLE_DB_USE_WALLET = os.environ.get('ORACLE_DB_USE_WALLET', 'false').lower() == 'true' +ORACLE_DB_USER = os.environ.get('ORACLE_DB_USER', None) # +ORACLE_DB_PASSWORD = os.environ.get('ORACLE_DB_PASSWORD', None) # +ORACLE_DB_DSN = os.environ.get('ORACLE_DB_DSN', None) # +ORACLE_WALLET_DIR = os.environ.get('ORACLE_WALLET_DIR', None) +ORACLE_WALLET_PASSWORD = os.environ.get('ORACLE_WALLET_PASSWORD', None) +ORACLE_VECTOR_LENGTH = os.environ.get('ORACLE_VECTOR_LENGTH', 768) -ORACLE_DB_POOL_MIN = int(os.environ.get("ORACLE_DB_POOL_MIN", 2)) -ORACLE_DB_POOL_MAX = int(os.environ.get("ORACLE_DB_POOL_MAX", 10)) -ORACLE_DB_POOL_INCREMENT = int(os.environ.get("ORACLE_DB_POOL_INCREMENT", 1)) +ORACLE_DB_POOL_MIN = int(os.environ.get('ORACLE_DB_POOL_MIN', 2)) +ORACLE_DB_POOL_MAX = int(os.environ.get('ORACLE_DB_POOL_MAX', 10)) +ORACLE_DB_POOL_INCREMENT = int(os.environ.get('ORACLE_DB_POOL_INCREMENT', 1)) -if VECTOR_DB == "oracle23ai": +if VECTOR_DB == 'oracle23ai': if not ORACLE_DB_USER or not ORACLE_DB_PASSWORD or not ORACLE_DB_DSN: - raise ValueError( - "Oracle23ai requires setting ORACLE_DB_USER, ORACLE_DB_PASSWORD, and ORACLE_DB_DSN." - ) + raise ValueError('Oracle23ai requires setting ORACLE_DB_USER, ORACLE_DB_PASSWORD, and ORACLE_DB_DSN.') if ORACLE_DB_USE_WALLET and (not ORACLE_WALLET_DIR or not ORACLE_WALLET_PASSWORD): raise ValueError( - "Oracle23ai requires setting ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD when using wallet authentication." + 'Oracle23ai requires setting ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD when using wallet authentication.' ) -log.info(f"VECTOR_DB: {VECTOR_DB}") +log.info(f'VECTOR_DB: {VECTOR_DB}') # S3 Vector -S3_VECTOR_BUCKET_NAME = os.environ.get("S3_VECTOR_BUCKET_NAME", None) -S3_VECTOR_REGION = os.environ.get("S3_VECTOR_REGION", None) +S3_VECTOR_BUCKET_NAME = os.environ.get('S3_VECTOR_BUCKET_NAME', None) +S3_VECTOR_REGION = os.environ.get('S3_VECTOR_REGION', None) #################################### # Information Retrieval (RAG) @@ -2702,476 +2461,435 @@ S3_VECTOR_REGION = os.environ.get("S3_VECTOR_REGION", None) # If configured, Google Drive will be available as an upload option. ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig( - "ENABLE_GOOGLE_DRIVE_INTEGRATION", - "google_drive.enable", - os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true", + 'ENABLE_GOOGLE_DRIVE_INTEGRATION', + 'google_drive.enable', + os.getenv('ENABLE_GOOGLE_DRIVE_INTEGRATION', 'False').lower() == 'true', ) GOOGLE_DRIVE_CLIENT_ID = PersistentConfig( - "GOOGLE_DRIVE_CLIENT_ID", - "google_drive.client_id", - os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""), + 'GOOGLE_DRIVE_CLIENT_ID', + 'google_drive.client_id', + os.environ.get('GOOGLE_DRIVE_CLIENT_ID', ''), ) GOOGLE_DRIVE_API_KEY = PersistentConfig( - "GOOGLE_DRIVE_API_KEY", - "google_drive.api_key", - os.environ.get("GOOGLE_DRIVE_API_KEY", ""), + 'GOOGLE_DRIVE_API_KEY', + 'google_drive.api_key', + os.environ.get('GOOGLE_DRIVE_API_KEY', ''), ) ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig( - "ENABLE_ONEDRIVE_INTEGRATION", - "onedrive.enable", - os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true", + 'ENABLE_ONEDRIVE_INTEGRATION', + 'onedrive.enable', + os.getenv('ENABLE_ONEDRIVE_INTEGRATION', 'False').lower() == 'true', ) -ENABLE_ONEDRIVE_PERSONAL = ( - os.environ.get("ENABLE_ONEDRIVE_PERSONAL", "True").lower() == "true" -) -ENABLE_ONEDRIVE_BUSINESS = ( - os.environ.get("ENABLE_ONEDRIVE_BUSINESS", "True").lower() == "true" -) +ENABLE_ONEDRIVE_PERSONAL = os.environ.get('ENABLE_ONEDRIVE_PERSONAL', 'True').lower() == 'true' +ENABLE_ONEDRIVE_BUSINESS = os.environ.get('ENABLE_ONEDRIVE_BUSINESS', 'True').lower() == 'true' -ONEDRIVE_CLIENT_ID = os.environ.get("ONEDRIVE_CLIENT_ID", "") -ONEDRIVE_CLIENT_ID_PERSONAL = os.environ.get( - "ONEDRIVE_CLIENT_ID_PERSONAL", ONEDRIVE_CLIENT_ID -) -ONEDRIVE_CLIENT_ID_BUSINESS = os.environ.get( - "ONEDRIVE_CLIENT_ID_BUSINESS", ONEDRIVE_CLIENT_ID -) +ONEDRIVE_CLIENT_ID = os.environ.get('ONEDRIVE_CLIENT_ID', '') +ONEDRIVE_CLIENT_ID_PERSONAL = os.environ.get('ONEDRIVE_CLIENT_ID_PERSONAL', ONEDRIVE_CLIENT_ID) +ONEDRIVE_CLIENT_ID_BUSINESS = os.environ.get('ONEDRIVE_CLIENT_ID_BUSINESS', ONEDRIVE_CLIENT_ID) ONEDRIVE_SHAREPOINT_URL = PersistentConfig( - "ONEDRIVE_SHAREPOINT_URL", - "onedrive.sharepoint_url", - os.environ.get("ONEDRIVE_SHAREPOINT_URL", ""), + 'ONEDRIVE_SHAREPOINT_URL', + 'onedrive.sharepoint_url', + os.environ.get('ONEDRIVE_SHAREPOINT_URL', ''), ) ONEDRIVE_SHAREPOINT_TENANT_ID = PersistentConfig( - "ONEDRIVE_SHAREPOINT_TENANT_ID", - "onedrive.sharepoint_tenant_id", - os.environ.get("ONEDRIVE_SHAREPOINT_TENANT_ID", ""), + 'ONEDRIVE_SHAREPOINT_TENANT_ID', + 'onedrive.sharepoint_tenant_id', + os.environ.get('ONEDRIVE_SHAREPOINT_TENANT_ID', ''), ) # RAG Content Extraction CONTENT_EXTRACTION_ENGINE = PersistentConfig( - "CONTENT_EXTRACTION_ENGINE", - "rag.CONTENT_EXTRACTION_ENGINE", - os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), + 'CONTENT_EXTRACTION_ENGINE', + 'rag.CONTENT_EXTRACTION_ENGINE', + os.environ.get('CONTENT_EXTRACTION_ENGINE', '').lower(), ) DATALAB_MARKER_API_KEY = PersistentConfig( - "DATALAB_MARKER_API_KEY", - "rag.datalab_marker_api_key", - os.environ.get("DATALAB_MARKER_API_KEY", ""), + 'DATALAB_MARKER_API_KEY', + 'rag.datalab_marker_api_key', + os.environ.get('DATALAB_MARKER_API_KEY', ''), ) DATALAB_MARKER_API_BASE_URL = PersistentConfig( - "DATALAB_MARKER_API_BASE_URL", - "rag.datalab_marker_api_base_url", - os.environ.get("DATALAB_MARKER_API_BASE_URL", ""), + 'DATALAB_MARKER_API_BASE_URL', + 'rag.datalab_marker_api_base_url', + os.environ.get('DATALAB_MARKER_API_BASE_URL', ''), ) DATALAB_MARKER_ADDITIONAL_CONFIG = PersistentConfig( - "DATALAB_MARKER_ADDITIONAL_CONFIG", - "rag.datalab_marker_additional_config", - os.environ.get("DATALAB_MARKER_ADDITIONAL_CONFIG", ""), + 'DATALAB_MARKER_ADDITIONAL_CONFIG', + 'rag.datalab_marker_additional_config', + os.environ.get('DATALAB_MARKER_ADDITIONAL_CONFIG', ''), ) DATALAB_MARKER_USE_LLM = PersistentConfig( - "DATALAB_MARKER_USE_LLM", - "rag.DATALAB_MARKER_USE_LLM", - os.environ.get("DATALAB_MARKER_USE_LLM", "false").lower() == "true", + 'DATALAB_MARKER_USE_LLM', + 'rag.DATALAB_MARKER_USE_LLM', + os.environ.get('DATALAB_MARKER_USE_LLM', 'false').lower() == 'true', ) DATALAB_MARKER_SKIP_CACHE = PersistentConfig( - "DATALAB_MARKER_SKIP_CACHE", - "rag.datalab_marker_skip_cache", - os.environ.get("DATALAB_MARKER_SKIP_CACHE", "false").lower() == "true", + 'DATALAB_MARKER_SKIP_CACHE', + 'rag.datalab_marker_skip_cache', + os.environ.get('DATALAB_MARKER_SKIP_CACHE', 'false').lower() == 'true', ) DATALAB_MARKER_FORCE_OCR = PersistentConfig( - "DATALAB_MARKER_FORCE_OCR", - "rag.datalab_marker_force_ocr", - os.environ.get("DATALAB_MARKER_FORCE_OCR", "false").lower() == "true", + 'DATALAB_MARKER_FORCE_OCR', + 'rag.datalab_marker_force_ocr', + os.environ.get('DATALAB_MARKER_FORCE_OCR', 'false').lower() == 'true', ) DATALAB_MARKER_PAGINATE = PersistentConfig( - "DATALAB_MARKER_PAGINATE", - "rag.datalab_marker_paginate", - os.environ.get("DATALAB_MARKER_PAGINATE", "false").lower() == "true", + 'DATALAB_MARKER_PAGINATE', + 'rag.datalab_marker_paginate', + os.environ.get('DATALAB_MARKER_PAGINATE', 'false').lower() == 'true', ) DATALAB_MARKER_STRIP_EXISTING_OCR = PersistentConfig( - "DATALAB_MARKER_STRIP_EXISTING_OCR", - "rag.datalab_marker_strip_existing_ocr", - os.environ.get("DATALAB_MARKER_STRIP_EXISTING_OCR", "false").lower() == "true", + 'DATALAB_MARKER_STRIP_EXISTING_OCR', + 'rag.datalab_marker_strip_existing_ocr', + os.environ.get('DATALAB_MARKER_STRIP_EXISTING_OCR', 'false').lower() == 'true', ) DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig( - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", - "rag.datalab_marker_disable_image_extraction", - os.environ.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", "false").lower() - == "true", + 'DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', + 'rag.datalab_marker_disable_image_extraction', + os.environ.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', 'false').lower() == 'true', ) DATALAB_MARKER_FORMAT_LINES = PersistentConfig( - "DATALAB_MARKER_FORMAT_LINES", - "rag.datalab_marker_format_lines", - os.environ.get("DATALAB_MARKER_FORMAT_LINES", "false").lower() == "true", + 'DATALAB_MARKER_FORMAT_LINES', + 'rag.datalab_marker_format_lines', + os.environ.get('DATALAB_MARKER_FORMAT_LINES', 'false').lower() == 'true', ) DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig( - "DATALAB_MARKER_OUTPUT_FORMAT", - "rag.datalab_marker_output_format", - os.environ.get("DATALAB_MARKER_OUTPUT_FORMAT", "markdown"), + 'DATALAB_MARKER_OUTPUT_FORMAT', + 'rag.datalab_marker_output_format', + os.environ.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'), ) MINERU_API_MODE = PersistentConfig( - "MINERU_API_MODE", - "rag.mineru_api_mode", - os.environ.get("MINERU_API_MODE", "local"), # "local" or "cloud" + 'MINERU_API_MODE', + 'rag.mineru_api_mode', + os.environ.get('MINERU_API_MODE', 'local'), # "local" or "cloud" ) MINERU_API_URL = PersistentConfig( - "MINERU_API_URL", - "rag.mineru_api_url", - os.environ.get("MINERU_API_URL", "http://localhost:8000"), + 'MINERU_API_URL', + 'rag.mineru_api_url', + os.environ.get('MINERU_API_URL', 'http://localhost:8000'), ) MINERU_API_TIMEOUT = PersistentConfig( - "MINERU_API_TIMEOUT", - "rag.mineru_api_timeout", - os.environ.get("MINERU_API_TIMEOUT", "300"), + 'MINERU_API_TIMEOUT', + 'rag.mineru_api_timeout', + os.environ.get('MINERU_API_TIMEOUT', '300'), ) MINERU_API_KEY = PersistentConfig( - "MINERU_API_KEY", - "rag.mineru_api_key", - os.environ.get("MINERU_API_KEY", ""), + 'MINERU_API_KEY', + 'rag.mineru_api_key', + os.environ.get('MINERU_API_KEY', ''), ) -mineru_params = os.getenv("MINERU_PARAMS", "") +mineru_params = os.getenv('MINERU_PARAMS', '') try: mineru_params = json.loads(mineru_params) except json.JSONDecodeError: mineru_params = {} MINERU_PARAMS = PersistentConfig( - "MINERU_PARAMS", - "rag.mineru_params", + 'MINERU_PARAMS', + 'rag.mineru_params', mineru_params, ) EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig( - "EXTERNAL_DOCUMENT_LOADER_URL", - "rag.external_document_loader_url", - os.environ.get("EXTERNAL_DOCUMENT_LOADER_URL", ""), + 'EXTERNAL_DOCUMENT_LOADER_URL', + 'rag.external_document_loader_url', + os.environ.get('EXTERNAL_DOCUMENT_LOADER_URL', ''), ) EXTERNAL_DOCUMENT_LOADER_API_KEY = PersistentConfig( - "EXTERNAL_DOCUMENT_LOADER_API_KEY", - "rag.external_document_loader_api_key", - os.environ.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", ""), + 'EXTERNAL_DOCUMENT_LOADER_API_KEY', + 'rag.external_document_loader_api_key', + os.environ.get('EXTERNAL_DOCUMENT_LOADER_API_KEY', ''), ) TIKA_SERVER_URL = PersistentConfig( - "TIKA_SERVER_URL", - "rag.tika_server_url", - os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment + 'TIKA_SERVER_URL', + 'rag.tika_server_url', + os.getenv('TIKA_SERVER_URL', 'http://tika:9998'), # Default for sidecar deployment ) DOCLING_SERVER_URL = PersistentConfig( - "DOCLING_SERVER_URL", - "rag.docling_server_url", - os.getenv("DOCLING_SERVER_URL", "http://docling:5001"), + 'DOCLING_SERVER_URL', + 'rag.docling_server_url', + os.getenv('DOCLING_SERVER_URL', 'http://docling:5001'), ) DOCLING_API_KEY = PersistentConfig( - "DOCLING_API_KEY", - "rag.docling_api_key", - os.getenv("DOCLING_API_KEY", ""), + 'DOCLING_API_KEY', + 'rag.docling_api_key', + os.getenv('DOCLING_API_KEY', ''), ) -docling_params = os.getenv("DOCLING_PARAMS", "") +docling_params = os.getenv('DOCLING_PARAMS', '') try: docling_params = json.loads(docling_params) except json.JSONDecodeError: docling_params = {} DOCLING_PARAMS = PersistentConfig( - "DOCLING_PARAMS", - "rag.docling_params", + 'DOCLING_PARAMS', + 'rag.docling_params', docling_params, ) DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig( - "DOCUMENT_INTELLIGENCE_ENDPOINT", - "rag.document_intelligence_endpoint", - os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""), + 'DOCUMENT_INTELLIGENCE_ENDPOINT', + 'rag.document_intelligence_endpoint', + os.getenv('DOCUMENT_INTELLIGENCE_ENDPOINT', ''), ) DOCUMENT_INTELLIGENCE_KEY = PersistentConfig( - "DOCUMENT_INTELLIGENCE_KEY", - "rag.document_intelligence_key", - os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""), + 'DOCUMENT_INTELLIGENCE_KEY', + 'rag.document_intelligence_key', + os.getenv('DOCUMENT_INTELLIGENCE_KEY', ''), ) DOCUMENT_INTELLIGENCE_MODEL = PersistentConfig( - "DOCUMENT_INTELLIGENCE_MODEL", - "rag.document_intelligence_model", - os.getenv("DOCUMENT_INTELLIGENCE_MODEL", "prebuilt-layout"), + 'DOCUMENT_INTELLIGENCE_MODEL', + 'rag.document_intelligence_model', + os.getenv('DOCUMENT_INTELLIGENCE_MODEL', 'prebuilt-layout'), ) MISTRAL_OCR_API_BASE_URL = PersistentConfig( - "MISTRAL_OCR_API_BASE_URL", - "rag.MISTRAL_OCR_API_BASE_URL", - os.getenv("MISTRAL_OCR_API_BASE_URL", "https://api.mistral.ai/v1"), + 'MISTRAL_OCR_API_BASE_URL', + 'rag.MISTRAL_OCR_API_BASE_URL', + os.getenv('MISTRAL_OCR_API_BASE_URL', 'https://api.mistral.ai/v1'), ) MISTRAL_OCR_API_KEY = PersistentConfig( - "MISTRAL_OCR_API_KEY", - "rag.mistral_ocr_api_key", - os.getenv("MISTRAL_OCR_API_KEY", ""), + 'MISTRAL_OCR_API_KEY', + 'rag.mistral_ocr_api_key', + os.getenv('MISTRAL_OCR_API_KEY', ''), ) BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig( - "BYPASS_EMBEDDING_AND_RETRIEVAL", - "rag.bypass_embedding_and_retrieval", - os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true", + 'BYPASS_EMBEDDING_AND_RETRIEVAL', + 'rag.bypass_embedding_and_retrieval', + os.environ.get('BYPASS_EMBEDDING_AND_RETRIEVAL', 'False').lower() == 'true', ) -RAG_TOP_K = PersistentConfig( - "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3")) -) +RAG_TOP_K = PersistentConfig('RAG_TOP_K', 'rag.top_k', int(os.environ.get('RAG_TOP_K', '3'))) RAG_TOP_K_RERANKER = PersistentConfig( - "RAG_TOP_K_RERANKER", - "rag.top_k_reranker", - int(os.environ.get("RAG_TOP_K_RERANKER", "3")), + 'RAG_TOP_K_RERANKER', + 'rag.top_k_reranker', + int(os.environ.get('RAG_TOP_K_RERANKER', '3')), ) RAG_RELEVANCE_THRESHOLD = PersistentConfig( - "RAG_RELEVANCE_THRESHOLD", - "rag.relevance_threshold", - float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), + 'RAG_RELEVANCE_THRESHOLD', + 'rag.relevance_threshold', + float(os.environ.get('RAG_RELEVANCE_THRESHOLD', '0.0')), ) RAG_HYBRID_BM25_WEIGHT = PersistentConfig( - "RAG_HYBRID_BM25_WEIGHT", - "rag.hybrid_bm25_weight", - float(os.environ.get("RAG_HYBRID_BM25_WEIGHT", "0.5")), + 'RAG_HYBRID_BM25_WEIGHT', + 'rag.hybrid_bm25_weight', + float(os.environ.get('RAG_HYBRID_BM25_WEIGHT', '0.5')), ) ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( - "ENABLE_RAG_HYBRID_SEARCH", - "rag.enable_hybrid_search", - os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", + 'ENABLE_RAG_HYBRID_SEARCH', + 'rag.enable_hybrid_search', + os.environ.get('ENABLE_RAG_HYBRID_SEARCH', '').lower() == 'true', ) ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = PersistentConfig( - "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS", - "rag.enable_hybrid_search_enriched_texts", - os.environ.get("ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS", "False").lower() - == "true", + 'ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS', + 'rag.enable_hybrid_search_enriched_texts', + os.environ.get('ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS', 'False').lower() == 'true', ) RAG_FULL_CONTEXT = PersistentConfig( - "RAG_FULL_CONTEXT", - "rag.full_context", - os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true", + 'RAG_FULL_CONTEXT', + 'rag.full_context', + os.getenv('RAG_FULL_CONTEXT', 'False').lower() == 'true', ) RAG_FILE_MAX_COUNT = PersistentConfig( - "RAG_FILE_MAX_COUNT", - "rag.file.max_count", - ( - int(os.environ.get("RAG_FILE_MAX_COUNT")) - if os.environ.get("RAG_FILE_MAX_COUNT") - else None - ), + 'RAG_FILE_MAX_COUNT', + 'rag.file.max_count', + (int(os.environ.get('RAG_FILE_MAX_COUNT')) if os.environ.get('RAG_FILE_MAX_COUNT') else None), ) RAG_FILE_MAX_SIZE = PersistentConfig( - "RAG_FILE_MAX_SIZE", - "rag.file.max_size", - ( - int(os.environ.get("RAG_FILE_MAX_SIZE")) - if os.environ.get("RAG_FILE_MAX_SIZE") - else None - ), + 'RAG_FILE_MAX_SIZE', + 'rag.file.max_size', + (int(os.environ.get('RAG_FILE_MAX_SIZE')) if os.environ.get('RAG_FILE_MAX_SIZE') else None), ) FILE_IMAGE_COMPRESSION_WIDTH = PersistentConfig( - "FILE_IMAGE_COMPRESSION_WIDTH", - "file.image_compression_width", - ( - int(os.environ.get("FILE_IMAGE_COMPRESSION_WIDTH")) - if os.environ.get("FILE_IMAGE_COMPRESSION_WIDTH") - else None - ), + 'FILE_IMAGE_COMPRESSION_WIDTH', + 'file.image_compression_width', + (int(os.environ.get('FILE_IMAGE_COMPRESSION_WIDTH')) if os.environ.get('FILE_IMAGE_COMPRESSION_WIDTH') else None), ) FILE_IMAGE_COMPRESSION_HEIGHT = PersistentConfig( - "FILE_IMAGE_COMPRESSION_HEIGHT", - "file.image_compression_height", - ( - int(os.environ.get("FILE_IMAGE_COMPRESSION_HEIGHT")) - if os.environ.get("FILE_IMAGE_COMPRESSION_HEIGHT") - else None - ), + 'FILE_IMAGE_COMPRESSION_HEIGHT', + 'file.image_compression_height', + (int(os.environ.get('FILE_IMAGE_COMPRESSION_HEIGHT')) if os.environ.get('FILE_IMAGE_COMPRESSION_HEIGHT') else None), ) RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig( - "RAG_ALLOWED_FILE_EXTENSIONS", - "rag.file.allowed_extensions", - [ - ext.strip() - for ext in os.environ.get("RAG_ALLOWED_FILE_EXTENSIONS", "").split(",") - if ext.strip() - ], + 'RAG_ALLOWED_FILE_EXTENSIONS', + 'rag.file.allowed_extensions', + [ext.strip() for ext in os.environ.get('RAG_ALLOWED_FILE_EXTENSIONS', '').split(',') if ext.strip()], ) RAG_EMBEDDING_ENGINE = PersistentConfig( - "RAG_EMBEDDING_ENGINE", - "rag.embedding_engine", - os.environ.get("RAG_EMBEDDING_ENGINE", ""), + 'RAG_EMBEDDING_ENGINE', + 'rag.embedding_engine', + os.environ.get('RAG_EMBEDDING_ENGINE', ''), ) PDF_EXTRACT_IMAGES = PersistentConfig( - "PDF_EXTRACT_IMAGES", - "rag.pdf_extract_images", - os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", + 'PDF_EXTRACT_IMAGES', + 'rag.pdf_extract_images', + os.environ.get('PDF_EXTRACT_IMAGES', 'False').lower() == 'true', ) PDF_LOADER_MODE = PersistentConfig( - "PDF_LOADER_MODE", - "rag.pdf_loader_mode", - os.environ.get("PDF_LOADER_MODE", "page"), + 'PDF_LOADER_MODE', + 'rag.pdf_loader_mode', + os.environ.get('PDF_LOADER_MODE', 'page'), ) RAG_EMBEDDING_MODEL = PersistentConfig( - "RAG_EMBEDDING_MODEL", - "rag.embedding_model", - os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), + 'RAG_EMBEDDING_MODEL', + 'rag.embedding_model', + os.environ.get('RAG_EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2'), ) -log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}") +log.info(f'Embedding model set: {RAG_EMBEDDING_MODEL.value}') RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( - not OFFLINE_MODE - and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true" + not OFFLINE_MODE and os.environ.get('RAG_EMBEDDING_MODEL_AUTO_UPDATE', 'True').lower() == 'true' ) RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" + os.environ.get('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' ) RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( - "RAG_EMBEDDING_BATCH_SIZE", - "rag.embedding_batch_size", - int( - os.environ.get("RAG_EMBEDDING_BATCH_SIZE") - or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1") - ), + 'RAG_EMBEDDING_BATCH_SIZE', + 'rag.embedding_batch_size', + int(os.environ.get('RAG_EMBEDDING_BATCH_SIZE') or os.environ.get('RAG_EMBEDDING_OPENAI_BATCH_SIZE', '1')), ) ENABLE_ASYNC_EMBEDDING = PersistentConfig( - "ENABLE_ASYNC_EMBEDDING", - "rag.enable_async_embedding", - os.environ.get("ENABLE_ASYNC_EMBEDDING", "True").lower() == "true", + 'ENABLE_ASYNC_EMBEDDING', + 'rag.enable_async_embedding', + os.environ.get('ENABLE_ASYNC_EMBEDDING', 'True').lower() == 'true', ) RAG_EMBEDDING_CONCURRENT_REQUESTS = PersistentConfig( - "RAG_EMBEDDING_CONCURRENT_REQUESTS", - "rag.embedding_concurrent_requests", - int(os.getenv("RAG_EMBEDDING_CONCURRENT_REQUESTS", "0")), + 'RAG_EMBEDDING_CONCURRENT_REQUESTS', + 'rag.embedding_concurrent_requests', + int(os.getenv('RAG_EMBEDDING_CONCURRENT_REQUESTS', '0')), ) -RAG_EMBEDDING_QUERY_PREFIX = os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None) +RAG_EMBEDDING_QUERY_PREFIX = os.environ.get('RAG_EMBEDDING_QUERY_PREFIX', None) -RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get("RAG_EMBEDDING_CONTENT_PREFIX", None) +RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get('RAG_EMBEDDING_CONTENT_PREFIX', None) -RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get( - "RAG_EMBEDDING_PREFIX_FIELD_NAME", None -) +RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get('RAG_EMBEDDING_PREFIX_FIELD_NAME', None) RAG_RERANKING_ENGINE = PersistentConfig( - "RAG_RERANKING_ENGINE", - "rag.reranking_engine", - os.environ.get("RAG_RERANKING_ENGINE", ""), + 'RAG_RERANKING_ENGINE', + 'rag.reranking_engine', + os.environ.get('RAG_RERANKING_ENGINE', ''), ) RAG_RERANKING_MODEL = PersistentConfig( - "RAG_RERANKING_MODEL", - "rag.reranking_model", - os.environ.get("RAG_RERANKING_MODEL", ""), + 'RAG_RERANKING_MODEL', + 'rag.reranking_model', + os.environ.get('RAG_RERANKING_MODEL', ''), ) -if RAG_RERANKING_MODEL.value != "": - log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}") +if RAG_RERANKING_MODEL.value != '': + log.info(f'Reranking model set: {RAG_RERANKING_MODEL.value}') RAG_RERANKING_MODEL_AUTO_UPDATE = ( - not OFFLINE_MODE - and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true" + not OFFLINE_MODE and os.environ.get('RAG_RERANKING_MODEL_AUTO_UPDATE', 'True').lower() == 'true' ) RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" + os.environ.get('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' ) RAG_EXTERNAL_RERANKER_URL = PersistentConfig( - "RAG_EXTERNAL_RERANKER_URL", - "rag.external_reranker_url", - os.environ.get("RAG_EXTERNAL_RERANKER_URL", ""), + 'RAG_EXTERNAL_RERANKER_URL', + 'rag.external_reranker_url', + os.environ.get('RAG_EXTERNAL_RERANKER_URL', ''), ) RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig( - "RAG_EXTERNAL_RERANKER_API_KEY", - "rag.external_reranker_api_key", - os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""), + 'RAG_EXTERNAL_RERANKER_API_KEY', + 'rag.external_reranker_api_key', + os.environ.get('RAG_EXTERNAL_RERANKER_API_KEY', ''), ) RAG_EXTERNAL_RERANKER_TIMEOUT = PersistentConfig( - "RAG_EXTERNAL_RERANKER_TIMEOUT", - "rag.external_reranker_timeout", - os.environ.get("RAG_EXTERNAL_RERANKER_TIMEOUT", ""), + 'RAG_EXTERNAL_RERANKER_TIMEOUT', + 'rag.external_reranker_timeout', + os.environ.get('RAG_EXTERNAL_RERANKER_TIMEOUT', ''), ) RAG_TEXT_SPLITTER = PersistentConfig( - "RAG_TEXT_SPLITTER", - "rag.text_splitter", - os.environ.get("RAG_TEXT_SPLITTER", ""), + 'RAG_TEXT_SPLITTER', + 'rag.text_splitter', + os.environ.get('RAG_TEXT_SPLITTER', ''), ) ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = PersistentConfig( - "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", - "rag.enable_markdown_header_text_splitter", - os.environ.get("ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", "True").lower() == "true", + 'ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER', + 'rag.enable_markdown_header_text_splitter', + os.environ.get('ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER', 'True').lower() == 'true', ) -TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") +TIKTOKEN_CACHE_DIR = os.environ.get('TIKTOKEN_CACHE_DIR', f'{CACHE_DIR}/tiktoken') TIKTOKEN_ENCODING_NAME = PersistentConfig( - "TIKTOKEN_ENCODING_NAME", - "rag.tiktoken_encoding_name", - os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"), + 'TIKTOKEN_ENCODING_NAME', + 'rag.tiktoken_encoding_name', + os.environ.get('TIKTOKEN_ENCODING_NAME', 'cl100k_base'), ) -CHUNK_SIZE = PersistentConfig( - "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000")) -) +CHUNK_SIZE = PersistentConfig('CHUNK_SIZE', 'rag.chunk_size', int(os.environ.get('CHUNK_SIZE', '1000'))) CHUNK_MIN_SIZE_TARGET = PersistentConfig( - "CHUNK_MIN_SIZE_TARGET", - "rag.chunk_min_size_target", - int(os.environ.get("CHUNK_MIN_SIZE_TARGET", "0")), + 'CHUNK_MIN_SIZE_TARGET', + 'rag.chunk_min_size_target', + int(os.environ.get('CHUNK_MIN_SIZE_TARGET', '0')), ) CHUNK_OVERLAP = PersistentConfig( - "CHUNK_OVERLAP", - "rag.chunk_overlap", - int(os.environ.get("CHUNK_OVERLAP", "100")), + 'CHUNK_OVERLAP', + 'rag.chunk_overlap', + int(os.environ.get('CHUNK_OVERLAP', '100')), ) DEFAULT_RAG_TEMPLATE = """### Task: @@ -3201,85 +2919,81 @@ Provide a clear and direct response to the user's query, including inline citati """ RAG_TEMPLATE = PersistentConfig( - "RAG_TEMPLATE", - "rag.template", - os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), + 'RAG_TEMPLATE', + 'rag.template', + os.environ.get('RAG_TEMPLATE', DEFAULT_RAG_TEMPLATE), ) RAG_OPENAI_API_BASE_URL = PersistentConfig( - "RAG_OPENAI_API_BASE_URL", - "rag.openai_api_base_url", - os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), + 'RAG_OPENAI_API_BASE_URL', + 'rag.openai_api_base_url', + os.getenv('RAG_OPENAI_API_BASE_URL', OPENAI_API_BASE_URL), ) RAG_OPENAI_API_KEY = PersistentConfig( - "RAG_OPENAI_API_KEY", - "rag.openai_api_key", - os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), + 'RAG_OPENAI_API_KEY', + 'rag.openai_api_key', + os.getenv('RAG_OPENAI_API_KEY', OPENAI_API_KEY), ) RAG_AZURE_OPENAI_BASE_URL = PersistentConfig( - "RAG_AZURE_OPENAI_BASE_URL", - "rag.azure_openai.base_url", - os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""), + 'RAG_AZURE_OPENAI_BASE_URL', + 'rag.azure_openai.base_url', + os.getenv('RAG_AZURE_OPENAI_BASE_URL', ''), ) RAG_AZURE_OPENAI_API_KEY = PersistentConfig( - "RAG_AZURE_OPENAI_API_KEY", - "rag.azure_openai.api_key", - os.getenv("RAG_AZURE_OPENAI_API_KEY", ""), + 'RAG_AZURE_OPENAI_API_KEY', + 'rag.azure_openai.api_key', + os.getenv('RAG_AZURE_OPENAI_API_KEY', ''), ) RAG_AZURE_OPENAI_API_VERSION = PersistentConfig( - "RAG_AZURE_OPENAI_API_VERSION", - "rag.azure_openai.api_version", - os.getenv("RAG_AZURE_OPENAI_API_VERSION", ""), + 'RAG_AZURE_OPENAI_API_VERSION', + 'rag.azure_openai.api_version', + os.getenv('RAG_AZURE_OPENAI_API_VERSION', ''), ) RAG_OLLAMA_BASE_URL = PersistentConfig( - "RAG_OLLAMA_BASE_URL", - "rag.ollama.url", - os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL), + 'RAG_OLLAMA_BASE_URL', + 'rag.ollama.url', + os.getenv('RAG_OLLAMA_BASE_URL', OLLAMA_BASE_URL), ) RAG_OLLAMA_API_KEY = PersistentConfig( - "RAG_OLLAMA_API_KEY", - "rag.ollama.key", - os.getenv("RAG_OLLAMA_API_KEY", ""), + 'RAG_OLLAMA_API_KEY', + 'rag.ollama.key', + os.getenv('RAG_OLLAMA_API_KEY', ''), ) -ENABLE_RAG_LOCAL_WEB_FETCH = ( - os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" -) +ENABLE_RAG_LOCAL_WEB_FETCH = os.getenv('ENABLE_RAG_LOCAL_WEB_FETCH', 'False').lower() == 'true' DEFAULT_WEB_FETCH_FILTER_LIST = [ - "!169.254.169.254", - "!fd00:ec2::254", - "!metadata.google.internal", - "!metadata.azure.com", - "!100.100.100.200", + '!169.254.169.254', + '!fd00:ec2::254', + '!metadata.google.internal', + '!metadata.azure.com', + '!100.100.100.200', ] -web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "") -if web_fetch_filter_list == "": +web_fetch_filter_list = os.getenv('WEB_FETCH_FILTER_LIST', '') +if web_fetch_filter_list == '': web_fetch_filter_list = [] else: - web_fetch_filter_list = [ - item.strip() for item in web_fetch_filter_list.split(",") if item.strip() - ] + web_fetch_filter_list = [item.strip() for item in web_fetch_filter_list.split(',') if item.strip()] WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list)) YOUTUBE_LOADER_LANGUAGE = PersistentConfig( - "YOUTUBE_LOADER_LANGUAGE", - "rag.youtube_loader_language", - os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), + 'YOUTUBE_LOADER_LANGUAGE', + 'rag.youtube_loader_language', + os.getenv('YOUTUBE_LOADER_LANGUAGE', 'en').split(','), ) YOUTUBE_LOADER_PROXY_URL = PersistentConfig( - "YOUTUBE_LOADER_PROXY_URL", - "rag.youtube_loader_proxy_url", - os.getenv("YOUTUBE_LOADER_PROXY_URL", ""), + 'YOUTUBE_LOADER_PROXY_URL', + 'rag.youtube_loader_proxy_url', + os.getenv('YOUTUBE_LOADER_PROXY_URL', ''), ) @@ -3288,41 +3002,39 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig( #################################### ENABLE_WEB_SEARCH = PersistentConfig( - "ENABLE_WEB_SEARCH", - "rag.web.search.enable", - os.getenv("ENABLE_WEB_SEARCH", "False").lower() == "true", + 'ENABLE_WEB_SEARCH', + 'rag.web.search.enable', + os.getenv('ENABLE_WEB_SEARCH', 'False').lower() == 'true', ) WEB_SEARCH_ENGINE = PersistentConfig( - "WEB_SEARCH_ENGINE", - "rag.web.search.engine", - os.getenv("WEB_SEARCH_ENGINE", ""), + 'WEB_SEARCH_ENGINE', + 'rag.web.search.engine', + os.getenv('WEB_SEARCH_ENGINE', ''), ) BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig( - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", - "rag.web.search.bypass_embedding_and_retrieval", - os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true", + 'BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL', + 'rag.web.search.bypass_embedding_and_retrieval', + os.getenv('BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL', 'False').lower() == 'true', ) BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig( - "BYPASS_WEB_SEARCH_WEB_LOADER", - "rag.web.search.bypass_web_loader", - os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true", + 'BYPASS_WEB_SEARCH_WEB_LOADER', + 'rag.web.search.bypass_web_loader', + os.getenv('BYPASS_WEB_SEARCH_WEB_LOADER', 'False').lower() == 'true', ) WEB_SEARCH_RESULT_COUNT = PersistentConfig( - "WEB_SEARCH_RESULT_COUNT", - "rag.web.search.result_count", - int(os.getenv("WEB_SEARCH_RESULT_COUNT", "3")), + 'WEB_SEARCH_RESULT_COUNT', + 'rag.web.search.result_count', + int(os.getenv('WEB_SEARCH_RESULT_COUNT', '3')), ) try: - web_search_domain_filter_list = json.loads( - os.getenv("WEB_SEARCH_DOMAIN_FILTER_LIST", "[]") - ) + web_search_domain_filter_list = json.loads(os.getenv('WEB_SEARCH_DOMAIN_FILTER_LIST', '[]')) except Exception as e: web_search_domain_filter_list = [ # "wikipedia.com", @@ -3334,360 +3046,354 @@ except Exception as e: # You can provide a list of your own websites to filter after performing a web search. # This ensures the highest level of safety and reliability of the information sources. WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig( - "WEB_SEARCH_DOMAIN_FILTER_LIST", - "rag.web.search.domain.filter_list", + 'WEB_SEARCH_DOMAIN_FILTER_LIST', + 'rag.web.search.domain.filter_list', web_search_domain_filter_list, ) WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( - "WEB_SEARCH_CONCURRENT_REQUESTS", - "rag.web.search.concurrent_requests", - int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "0")), + 'WEB_SEARCH_CONCURRENT_REQUESTS', + 'rag.web.search.concurrent_requests', + int(os.getenv('WEB_SEARCH_CONCURRENT_REQUESTS', '0')), ) WEB_FETCH_MAX_CONTENT_LENGTH = PersistentConfig( - "WEB_FETCH_MAX_CONTENT_LENGTH", - "rag.web.search.fetch_url_max_content_length", - ( - int(os.environ.get("WEB_FETCH_MAX_CONTENT_LENGTH")) - if os.environ.get("WEB_FETCH_MAX_CONTENT_LENGTH") - else None - ), + 'WEB_FETCH_MAX_CONTENT_LENGTH', + 'rag.web.search.fetch_url_max_content_length', + (int(os.environ.get('WEB_FETCH_MAX_CONTENT_LENGTH')) if os.environ.get('WEB_FETCH_MAX_CONTENT_LENGTH') else None), ) WEB_LOADER_ENGINE = PersistentConfig( - "WEB_LOADER_ENGINE", - "rag.web.loader.engine", - os.environ.get("WEB_LOADER_ENGINE", ""), + 'WEB_LOADER_ENGINE', + 'rag.web.loader.engine', + os.environ.get('WEB_LOADER_ENGINE', ''), ) WEB_LOADER_CONCURRENT_REQUESTS = PersistentConfig( - "WEB_LOADER_CONCURRENT_REQUESTS", - "rag.web.loader.concurrent_requests", - int(os.getenv("WEB_LOADER_CONCURRENT_REQUESTS", "10")), + 'WEB_LOADER_CONCURRENT_REQUESTS', + 'rag.web.loader.concurrent_requests', + int(os.getenv('WEB_LOADER_CONCURRENT_REQUESTS', '10')), ) WEB_LOADER_TIMEOUT = PersistentConfig( - "WEB_LOADER_TIMEOUT", - "rag.web.loader.timeout", - os.getenv("WEB_LOADER_TIMEOUT", ""), + 'WEB_LOADER_TIMEOUT', + 'rag.web.loader.timeout', + os.getenv('WEB_LOADER_TIMEOUT', ''), ) ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig( - "ENABLE_WEB_LOADER_SSL_VERIFICATION", - "rag.web.loader.ssl_verification", - os.environ.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", + 'ENABLE_WEB_LOADER_SSL_VERIFICATION', + 'rag.web.loader.ssl_verification', + os.environ.get('ENABLE_WEB_LOADER_SSL_VERIFICATION', 'True').lower() == 'true', ) WEB_SEARCH_TRUST_ENV = PersistentConfig( - "WEB_SEARCH_TRUST_ENV", - "rag.web.search.trust_env", - os.getenv("WEB_SEARCH_TRUST_ENV", "False").lower() == "true", + 'WEB_SEARCH_TRUST_ENV', + 'rag.web.search.trust_env', + os.getenv('WEB_SEARCH_TRUST_ENV', 'False').lower() == 'true', ) OLLAMA_CLOUD_WEB_SEARCH_API_KEY = PersistentConfig( - "OLLAMA_CLOUD_WEB_SEARCH_API_KEY", - "rag.web.search.ollama_cloud_api_key", - os.getenv("OLLAMA_CLOUD_API_KEY", ""), + 'OLLAMA_CLOUD_WEB_SEARCH_API_KEY', + 'rag.web.search.ollama_cloud_api_key', + os.getenv('OLLAMA_CLOUD_API_KEY', ''), ) SEARXNG_QUERY_URL = PersistentConfig( - "SEARXNG_QUERY_URL", - "rag.web.search.searxng_query_url", - os.getenv("SEARXNG_QUERY_URL", ""), + 'SEARXNG_QUERY_URL', + 'rag.web.search.searxng_query_url', + os.getenv('SEARXNG_QUERY_URL', ''), ) SEARXNG_LANGUAGE = PersistentConfig( - "SEARXNG_LANGUAGE", - "rag.web.search.searxng_language", - os.getenv("SEARXNG_LANGUAGE", "all"), + 'SEARXNG_LANGUAGE', + 'rag.web.search.searxng_language', + os.getenv('SEARXNG_LANGUAGE', 'all'), ) YACY_QUERY_URL = PersistentConfig( - "YACY_QUERY_URL", - "rag.web.search.yacy_query_url", - os.getenv("YACY_QUERY_URL", ""), + 'YACY_QUERY_URL', + 'rag.web.search.yacy_query_url', + os.getenv('YACY_QUERY_URL', ''), ) YACY_USERNAME = PersistentConfig( - "YACY_USERNAME", - "rag.web.search.yacy_username", - os.getenv("YACY_USERNAME", ""), + 'YACY_USERNAME', + 'rag.web.search.yacy_username', + os.getenv('YACY_USERNAME', ''), ) YACY_PASSWORD = PersistentConfig( - "YACY_PASSWORD", - "rag.web.search.yacy_password", - os.getenv("YACY_PASSWORD", ""), + 'YACY_PASSWORD', + 'rag.web.search.yacy_password', + os.getenv('YACY_PASSWORD', ''), ) GOOGLE_PSE_API_KEY = PersistentConfig( - "GOOGLE_PSE_API_KEY", - "rag.web.search.google_pse_api_key", - os.getenv("GOOGLE_PSE_API_KEY", ""), + 'GOOGLE_PSE_API_KEY', + 'rag.web.search.google_pse_api_key', + os.getenv('GOOGLE_PSE_API_KEY', ''), ) GOOGLE_PSE_ENGINE_ID = PersistentConfig( - "GOOGLE_PSE_ENGINE_ID", - "rag.web.search.google_pse_engine_id", - os.getenv("GOOGLE_PSE_ENGINE_ID", ""), + 'GOOGLE_PSE_ENGINE_ID', + 'rag.web.search.google_pse_engine_id', + os.getenv('GOOGLE_PSE_ENGINE_ID', ''), ) BRAVE_SEARCH_API_KEY = PersistentConfig( - "BRAVE_SEARCH_API_KEY", - "rag.web.search.brave_search_api_key", - os.getenv("BRAVE_SEARCH_API_KEY", ""), + 'BRAVE_SEARCH_API_KEY', + 'rag.web.search.brave_search_api_key', + os.getenv('BRAVE_SEARCH_API_KEY', ''), ) KAGI_SEARCH_API_KEY = PersistentConfig( - "KAGI_SEARCH_API_KEY", - "rag.web.search.kagi_search_api_key", - os.getenv("KAGI_SEARCH_API_KEY", ""), + 'KAGI_SEARCH_API_KEY', + 'rag.web.search.kagi_search_api_key', + os.getenv('KAGI_SEARCH_API_KEY', ''), ) MOJEEK_SEARCH_API_KEY = PersistentConfig( - "MOJEEK_SEARCH_API_KEY", - "rag.web.search.mojeek_search_api_key", - os.getenv("MOJEEK_SEARCH_API_KEY", ""), + 'MOJEEK_SEARCH_API_KEY', + 'rag.web.search.mojeek_search_api_key', + os.getenv('MOJEEK_SEARCH_API_KEY', ''), ) BOCHA_SEARCH_API_KEY = PersistentConfig( - "BOCHA_SEARCH_API_KEY", - "rag.web.search.bocha_search_api_key", - os.getenv("BOCHA_SEARCH_API_KEY", ""), + 'BOCHA_SEARCH_API_KEY', + 'rag.web.search.bocha_search_api_key', + os.getenv('BOCHA_SEARCH_API_KEY', ''), ) SERPSTACK_API_KEY = PersistentConfig( - "SERPSTACK_API_KEY", - "rag.web.search.serpstack_api_key", - os.getenv("SERPSTACK_API_KEY", ""), + 'SERPSTACK_API_KEY', + 'rag.web.search.serpstack_api_key', + os.getenv('SERPSTACK_API_KEY', ''), ) SERPSTACK_HTTPS = PersistentConfig( - "SERPSTACK_HTTPS", - "rag.web.search.serpstack_https", - os.getenv("SERPSTACK_HTTPS", "True").lower() == "true", + 'SERPSTACK_HTTPS', + 'rag.web.search.serpstack_https', + os.getenv('SERPSTACK_HTTPS', 'True').lower() == 'true', ) SERPER_API_KEY = PersistentConfig( - "SERPER_API_KEY", - "rag.web.search.serper_api_key", - os.getenv("SERPER_API_KEY", ""), + 'SERPER_API_KEY', + 'rag.web.search.serper_api_key', + os.getenv('SERPER_API_KEY', ''), ) SERPLY_API_KEY = PersistentConfig( - "SERPLY_API_KEY", - "rag.web.search.serply_api_key", - os.getenv("SERPLY_API_KEY", ""), + 'SERPLY_API_KEY', + 'rag.web.search.serply_api_key', + os.getenv('SERPLY_API_KEY', ''), ) DDGS_BACKEND = PersistentConfig( - "DDGS_BACKEND", - "rag.web.search.ddgs_backend", - os.getenv("DDGS_BACKEND", "auto"), + 'DDGS_BACKEND', + 'rag.web.search.ddgs_backend', + os.getenv('DDGS_BACKEND', 'auto'), ) JINA_API_KEY = PersistentConfig( - "JINA_API_KEY", - "rag.web.search.jina_api_key", - os.getenv("JINA_API_KEY", ""), + 'JINA_API_KEY', + 'rag.web.search.jina_api_key', + os.getenv('JINA_API_KEY', ''), ) JINA_API_BASE_URL = PersistentConfig( - "JINA_API_BASE_URL", - "rag.web.search.jina_api_base_url", - os.getenv("JINA_API_BASE_URL", ""), + 'JINA_API_BASE_URL', + 'rag.web.search.jina_api_base_url', + os.getenv('JINA_API_BASE_URL', ''), ) SEARCHAPI_API_KEY = PersistentConfig( - "SEARCHAPI_API_KEY", - "rag.web.search.searchapi_api_key", - os.getenv("SEARCHAPI_API_KEY", ""), + 'SEARCHAPI_API_KEY', + 'rag.web.search.searchapi_api_key', + os.getenv('SEARCHAPI_API_KEY', ''), ) SEARCHAPI_ENGINE = PersistentConfig( - "SEARCHAPI_ENGINE", - "rag.web.search.searchapi_engine", - os.getenv("SEARCHAPI_ENGINE", ""), + 'SEARCHAPI_ENGINE', + 'rag.web.search.searchapi_engine', + os.getenv('SEARCHAPI_ENGINE', ''), ) SERPAPI_API_KEY = PersistentConfig( - "SERPAPI_API_KEY", - "rag.web.search.serpapi_api_key", - os.getenv("SERPAPI_API_KEY", ""), + 'SERPAPI_API_KEY', + 'rag.web.search.serpapi_api_key', + os.getenv('SERPAPI_API_KEY', ''), ) SERPAPI_ENGINE = PersistentConfig( - "SERPAPI_ENGINE", - "rag.web.search.serpapi_engine", - os.getenv("SERPAPI_ENGINE", ""), + 'SERPAPI_ENGINE', + 'rag.web.search.serpapi_engine', + os.getenv('SERPAPI_ENGINE', ''), ) BING_SEARCH_V7_ENDPOINT = PersistentConfig( - "BING_SEARCH_V7_ENDPOINT", - "rag.web.search.bing_search_v7_endpoint", - os.environ.get( - "BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search" - ), + 'BING_SEARCH_V7_ENDPOINT', + 'rag.web.search.bing_search_v7_endpoint', + os.environ.get('BING_SEARCH_V7_ENDPOINT', 'https://api.bing.microsoft.com/v7.0/search'), ) BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig( - "BING_SEARCH_V7_SUBSCRIPTION_KEY", - "rag.web.search.bing_search_v7_subscription_key", - os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""), + 'BING_SEARCH_V7_SUBSCRIPTION_KEY', + 'rag.web.search.bing_search_v7_subscription_key', + os.environ.get('BING_SEARCH_V7_SUBSCRIPTION_KEY', ''), ) AZURE_AI_SEARCH_API_KEY = PersistentConfig( - "AZURE_AI_SEARCH_API_KEY", - "rag.web.search.azure_ai_search_api_key", - os.environ.get("AZURE_AI_SEARCH_API_KEY", ""), + 'AZURE_AI_SEARCH_API_KEY', + 'rag.web.search.azure_ai_search_api_key', + os.environ.get('AZURE_AI_SEARCH_API_KEY', ''), ) AZURE_AI_SEARCH_ENDPOINT = PersistentConfig( - "AZURE_AI_SEARCH_ENDPOINT", - "rag.web.search.azure_ai_search_endpoint", - os.environ.get("AZURE_AI_SEARCH_ENDPOINT", ""), + 'AZURE_AI_SEARCH_ENDPOINT', + 'rag.web.search.azure_ai_search_endpoint', + os.environ.get('AZURE_AI_SEARCH_ENDPOINT', ''), ) AZURE_AI_SEARCH_INDEX_NAME = PersistentConfig( - "AZURE_AI_SEARCH_INDEX_NAME", - "rag.web.search.azure_ai_search_index_name", - os.environ.get("AZURE_AI_SEARCH_INDEX_NAME", ""), + 'AZURE_AI_SEARCH_INDEX_NAME', + 'rag.web.search.azure_ai_search_index_name', + os.environ.get('AZURE_AI_SEARCH_INDEX_NAME', ''), ) EXA_API_KEY = PersistentConfig( - "EXA_API_KEY", - "rag.web.search.exa_api_key", - os.getenv("EXA_API_KEY", ""), + 'EXA_API_KEY', + 'rag.web.search.exa_api_key', + os.getenv('EXA_API_KEY', ''), ) PERPLEXITY_API_KEY = PersistentConfig( - "PERPLEXITY_API_KEY", - "rag.web.search.perplexity_api_key", - os.getenv("PERPLEXITY_API_KEY", ""), + 'PERPLEXITY_API_KEY', + 'rag.web.search.perplexity_api_key', + os.getenv('PERPLEXITY_API_KEY', ''), ) PERPLEXITY_MODEL = PersistentConfig( - "PERPLEXITY_MODEL", - "rag.web.search.perplexity_model", - os.getenv("PERPLEXITY_MODEL", "sonar"), + 'PERPLEXITY_MODEL', + 'rag.web.search.perplexity_model', + os.getenv('PERPLEXITY_MODEL', 'sonar'), ) PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig( - "PERPLEXITY_SEARCH_CONTEXT_USAGE", - "rag.web.search.perplexity_search_context_usage", - os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"), + 'PERPLEXITY_SEARCH_CONTEXT_USAGE', + 'rag.web.search.perplexity_search_context_usage', + os.getenv('PERPLEXITY_SEARCH_CONTEXT_USAGE', 'medium'), ) PERPLEXITY_SEARCH_API_URL = PersistentConfig( - "PERPLEXITY_SEARCH_API_URL", - "rag.web.search.perplexity_search_api_url", - os.getenv("PERPLEXITY_SEARCH_API_URL", "https://api.perplexity.ai/search"), + 'PERPLEXITY_SEARCH_API_URL', + 'rag.web.search.perplexity_search_api_url', + os.getenv('PERPLEXITY_SEARCH_API_URL', 'https://api.perplexity.ai/search'), ) SOUGOU_API_SID = PersistentConfig( - "SOUGOU_API_SID", - "rag.web.search.sougou_api_sid", - os.getenv("SOUGOU_API_SID", ""), + 'SOUGOU_API_SID', + 'rag.web.search.sougou_api_sid', + os.getenv('SOUGOU_API_SID', ''), ) SOUGOU_API_SK = PersistentConfig( - "SOUGOU_API_SK", - "rag.web.search.sougou_api_sk", - os.getenv("SOUGOU_API_SK", ""), + 'SOUGOU_API_SK', + 'rag.web.search.sougou_api_sk', + os.getenv('SOUGOU_API_SK', ''), ) TAVILY_API_KEY = PersistentConfig( - "TAVILY_API_KEY", - "rag.web.search.tavily_api_key", - os.getenv("TAVILY_API_KEY", ""), + 'TAVILY_API_KEY', + 'rag.web.search.tavily_api_key', + os.getenv('TAVILY_API_KEY', ''), ) TAVILY_EXTRACT_DEPTH = PersistentConfig( - "TAVILY_EXTRACT_DEPTH", - "rag.web.search.tavily_extract_depth", - os.getenv("TAVILY_EXTRACT_DEPTH", "basic"), + 'TAVILY_EXTRACT_DEPTH', + 'rag.web.search.tavily_extract_depth', + os.getenv('TAVILY_EXTRACT_DEPTH', 'basic'), ) PLAYWRIGHT_WS_URL = PersistentConfig( - "PLAYWRIGHT_WS_URL", - "rag.web.loader.playwright_ws_url", - os.environ.get("PLAYWRIGHT_WS_URL", ""), + 'PLAYWRIGHT_WS_URL', + 'rag.web.loader.playwright_ws_url', + os.environ.get('PLAYWRIGHT_WS_URL', ''), ) PLAYWRIGHT_TIMEOUT = PersistentConfig( - "PLAYWRIGHT_TIMEOUT", - "rag.web.loader.playwright_timeout", - int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10000")), + 'PLAYWRIGHT_TIMEOUT', + 'rag.web.loader.playwright_timeout', + int(os.environ.get('PLAYWRIGHT_TIMEOUT', '10000')), ) FIRECRAWL_API_KEY = PersistentConfig( - "FIRECRAWL_API_KEY", - "rag.web.loader.firecrawl_api_key", - os.environ.get("FIRECRAWL_API_KEY", ""), + 'FIRECRAWL_API_KEY', + 'rag.web.loader.firecrawl_api_key', + os.environ.get('FIRECRAWL_API_KEY', ''), ) FIRECRAWL_API_BASE_URL = PersistentConfig( - "FIRECRAWL_API_BASE_URL", - "rag.web.loader.firecrawl_api_url", - os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"), + 'FIRECRAWL_API_BASE_URL', + 'rag.web.loader.firecrawl_api_url', + os.environ.get('FIRECRAWL_API_BASE_URL', 'https://api.firecrawl.dev'), ) FIRECRAWL_TIMEOUT = PersistentConfig( - "FIRECRAWL_TIMEOUT", - "rag.web.loader.firecrawl_timeout", - os.environ.get("FIRECRAWL_TIMEOUT", ""), + 'FIRECRAWL_TIMEOUT', + 'rag.web.loader.firecrawl_timeout', + os.environ.get('FIRECRAWL_TIMEOUT', ''), ) EXTERNAL_WEB_SEARCH_URL = PersistentConfig( - "EXTERNAL_WEB_SEARCH_URL", - "rag.web.search.external_web_search_url", - os.environ.get("EXTERNAL_WEB_SEARCH_URL", ""), + 'EXTERNAL_WEB_SEARCH_URL', + 'rag.web.search.external_web_search_url', + os.environ.get('EXTERNAL_WEB_SEARCH_URL', ''), ) EXTERNAL_WEB_SEARCH_API_KEY = PersistentConfig( - "EXTERNAL_WEB_SEARCH_API_KEY", - "rag.web.search.external_web_search_api_key", - os.environ.get("EXTERNAL_WEB_SEARCH_API_KEY", ""), + 'EXTERNAL_WEB_SEARCH_API_KEY', + 'rag.web.search.external_web_search_api_key', + os.environ.get('EXTERNAL_WEB_SEARCH_API_KEY', ''), ) EXTERNAL_WEB_LOADER_URL = PersistentConfig( - "EXTERNAL_WEB_LOADER_URL", - "rag.web.loader.external_web_loader_url", - os.environ.get("EXTERNAL_WEB_LOADER_URL", ""), + 'EXTERNAL_WEB_LOADER_URL', + 'rag.web.loader.external_web_loader_url', + os.environ.get('EXTERNAL_WEB_LOADER_URL', ''), ) EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig( - "EXTERNAL_WEB_LOADER_API_KEY", - "rag.web.loader.external_web_loader_api_key", - os.environ.get("EXTERNAL_WEB_LOADER_API_KEY", ""), + 'EXTERNAL_WEB_LOADER_API_KEY', + 'rag.web.loader.external_web_loader_api_key', + os.environ.get('EXTERNAL_WEB_LOADER_API_KEY', ''), ) YANDEX_WEB_SEARCH_URL = PersistentConfig( - "YANDEX_WEB_SEARCH_URL", - "rag.web.search.yandex_web_search_url", - os.environ.get("YANDEX_WEB_SEARCH_URL", ""), + 'YANDEX_WEB_SEARCH_URL', + 'rag.web.search.yandex_web_search_url', + os.environ.get('YANDEX_WEB_SEARCH_URL', ''), ) YANDEX_WEB_SEARCH_API_KEY = PersistentConfig( - "YANDEX_WEB_SEARCH_API_KEY", - "rag.web.search.yandex_web_search_api_key", - os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""), + 'YANDEX_WEB_SEARCH_API_KEY', + 'rag.web.search.yandex_web_search_api_key', + os.environ.get('YANDEX_WEB_SEARCH_API_KEY', ''), ) YANDEX_WEB_SEARCH_CONFIG = PersistentConfig( - "YANDEX_WEB_SEARCH_CONFIG", - "rag.web.search.yandex_web_search_config", - os.environ.get("YANDEX_WEB_SEARCH_CONFIG", ""), + 'YANDEX_WEB_SEARCH_CONFIG', + 'rag.web.search.yandex_web_search_config', + os.environ.get('YANDEX_WEB_SEARCH_CONFIG', ''), ) YOUCOM_API_KEY = PersistentConfig( - "YOUCOM_API_KEY", - "rag.web.search.youcom_api_key", - os.environ.get("YOUCOM_API_KEY", ""), + 'YOUCOM_API_KEY', + 'rag.web.search.youcom_api_key', + os.environ.get('YOUCOM_API_KEY', ''), ) #################################### @@ -3695,80 +3401,72 @@ YOUCOM_API_KEY = PersistentConfig( #################################### ENABLE_IMAGE_GENERATION = PersistentConfig( - "ENABLE_IMAGE_GENERATION", - "image_generation.enable", - os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", + 'ENABLE_IMAGE_GENERATION', + 'image_generation.enable', + os.environ.get('ENABLE_IMAGE_GENERATION', '').lower() == 'true', ) IMAGE_GENERATION_ENGINE = PersistentConfig( - "IMAGE_GENERATION_ENGINE", - "image_generation.engine", - os.getenv("IMAGE_GENERATION_ENGINE", "openai"), + 'IMAGE_GENERATION_ENGINE', + 'image_generation.engine', + os.getenv('IMAGE_GENERATION_ENGINE', 'openai'), ) IMAGE_GENERATION_MODEL = PersistentConfig( - "IMAGE_GENERATION_MODEL", - "image_generation.model", - os.getenv("IMAGE_GENERATION_MODEL", ""), + 'IMAGE_GENERATION_MODEL', + 'image_generation.model', + os.getenv('IMAGE_GENERATION_MODEL', ''), ) # Regex pattern for models that support IMAGE_SIZE = "auto". -IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN = os.getenv( - "IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN", "^gpt-image" -) +IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN = os.getenv('IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN', '^gpt-image') # Regex pattern for models that return URLs instead of base64 data. -IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN = os.getenv( - "IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN", "^gpt-image" -) +IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN = os.getenv('IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN', '^gpt-image') -IMAGE_SIZE = PersistentConfig( - "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") -) +IMAGE_SIZE = PersistentConfig('IMAGE_SIZE', 'image_generation.size', os.getenv('IMAGE_SIZE', '512x512')) -IMAGE_STEPS = PersistentConfig( - "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) -) +IMAGE_STEPS = PersistentConfig('IMAGE_STEPS', 'image_generation.steps', int(os.getenv('IMAGE_STEPS', 50))) ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig( - "ENABLE_IMAGE_PROMPT_GENERATION", - "image_generation.prompt.enable", - os.environ.get("ENABLE_IMAGE_PROMPT_GENERATION", "true").lower() == "true", + 'ENABLE_IMAGE_PROMPT_GENERATION', + 'image_generation.prompt.enable', + os.environ.get('ENABLE_IMAGE_PROMPT_GENERATION', 'true').lower() == 'true', ) AUTOMATIC1111_BASE_URL = PersistentConfig( - "AUTOMATIC1111_BASE_URL", - "image_generation.automatic1111.base_url", - os.getenv("AUTOMATIC1111_BASE_URL", ""), + 'AUTOMATIC1111_BASE_URL', + 'image_generation.automatic1111.base_url', + os.getenv('AUTOMATIC1111_BASE_URL', ''), ) AUTOMATIC1111_API_AUTH = PersistentConfig( - "AUTOMATIC1111_API_AUTH", - "image_generation.automatic1111.api_auth", - os.getenv("AUTOMATIC1111_API_AUTH", ""), + 'AUTOMATIC1111_API_AUTH', + 'image_generation.automatic1111.api_auth', + os.getenv('AUTOMATIC1111_API_AUTH', ''), ) -automatic1111_params = os.getenv("AUTOMATIC1111_PARAMS", "") +automatic1111_params = os.getenv('AUTOMATIC1111_PARAMS', '') try: automatic1111_params = json.loads(automatic1111_params) except json.JSONDecodeError: automatic1111_params = {} AUTOMATIC1111_PARAMS = PersistentConfig( - "AUTOMATIC1111_PARAMS", - "image_generation.automatic1111.api_params", + 'AUTOMATIC1111_PARAMS', + 'image_generation.automatic1111.api_params', automatic1111_params, ) COMFYUI_BASE_URL = PersistentConfig( - "COMFYUI_BASE_URL", - "image_generation.comfyui.base_url", - os.getenv("COMFYUI_BASE_URL", ""), + 'COMFYUI_BASE_URL', + 'image_generation.comfyui.base_url', + os.getenv('COMFYUI_BASE_URL', ''), ) COMFYUI_API_KEY = PersistentConfig( - "COMFYUI_API_KEY", - "image_generation.comfyui.api_key", - os.getenv("COMFYUI_API_KEY", ""), + 'COMFYUI_API_KEY', + 'image_generation.comfyui.api_key', + os.getenv('COMFYUI_API_KEY', ''), ) COMFYUI_DEFAULT_WORKFLOW = """ @@ -3883,41 +3581,41 @@ COMFYUI_DEFAULT_WORKFLOW = """ COMFYUI_WORKFLOW = PersistentConfig( - "COMFYUI_WORKFLOW", - "image_generation.comfyui.workflow", - os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW), + 'COMFYUI_WORKFLOW', + 'image_generation.comfyui.workflow', + os.getenv('COMFYUI_WORKFLOW', COMFYUI_DEFAULT_WORKFLOW), ) -comfyui_workflow_nodes = os.getenv("COMFYUI_WORKFLOW_NODES", "") +comfyui_workflow_nodes = os.getenv('COMFYUI_WORKFLOW_NODES', '') try: comfyui_workflow_nodes = json.loads(comfyui_workflow_nodes) except json.JSONDecodeError: comfyui_workflow_nodes = [] COMFYUI_WORKFLOW_NODES = PersistentConfig( - "COMFYUI_WORKFLOW_NODES", - "image_generation.comfyui.nodes", + 'COMFYUI_WORKFLOW_NODES', + 'image_generation.comfyui.nodes', comfyui_workflow_nodes, ) IMAGES_OPENAI_API_BASE_URL = PersistentConfig( - "IMAGES_OPENAI_API_BASE_URL", - "image_generation.openai.api_base_url", - os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), + 'IMAGES_OPENAI_API_BASE_URL', + 'image_generation.openai.api_base_url', + os.getenv('IMAGES_OPENAI_API_BASE_URL', OPENAI_API_BASE_URL), ) IMAGES_OPENAI_API_VERSION = PersistentConfig( - "IMAGES_OPENAI_API_VERSION", - "image_generation.openai.api_version", - os.getenv("IMAGES_OPENAI_API_VERSION", ""), + 'IMAGES_OPENAI_API_VERSION', + 'image_generation.openai.api_version', + os.getenv('IMAGES_OPENAI_API_VERSION', ''), ) IMAGES_OPENAI_API_KEY = PersistentConfig( - "IMAGES_OPENAI_API_KEY", - "image_generation.openai.api_key", - os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), + 'IMAGES_OPENAI_API_KEY', + 'image_generation.openai.api_key', + os.getenv('IMAGES_OPENAI_API_KEY', OPENAI_API_KEY), ) -images_openai_params = os.getenv("IMAGES_OPENAI_PARAMS", "") +images_openai_params = os.getenv('IMAGES_OPENAI_PARAMS', '') try: images_openai_params = json.loads(images_openai_params) except json.JSONDecodeError: @@ -3925,104 +3623,102 @@ except json.JSONDecodeError: IMAGES_OPENAI_API_PARAMS = PersistentConfig( - "IMAGES_OPENAI_API_PARAMS", "image_generation.openai.params", images_openai_params + 'IMAGES_OPENAI_API_PARAMS', 'image_generation.openai.params', images_openai_params ) IMAGES_GEMINI_API_BASE_URL = PersistentConfig( - "IMAGES_GEMINI_API_BASE_URL", - "image_generation.gemini.api_base_url", - os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), + 'IMAGES_GEMINI_API_BASE_URL', + 'image_generation.gemini.api_base_url', + os.getenv('IMAGES_GEMINI_API_BASE_URL', GEMINI_API_BASE_URL), ) IMAGES_GEMINI_API_KEY = PersistentConfig( - "IMAGES_GEMINI_API_KEY", - "image_generation.gemini.api_key", - os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY), + 'IMAGES_GEMINI_API_KEY', + 'image_generation.gemini.api_key', + os.getenv('IMAGES_GEMINI_API_KEY', GEMINI_API_KEY), ) IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig( - "IMAGES_GEMINI_ENDPOINT_METHOD", - "image_generation.gemini.endpoint_method", - os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""), + 'IMAGES_GEMINI_ENDPOINT_METHOD', + 'image_generation.gemini.endpoint_method', + os.getenv('IMAGES_GEMINI_ENDPOINT_METHOD', ''), ) ENABLE_IMAGE_EDIT = PersistentConfig( - "ENABLE_IMAGE_EDIT", - "images.edit.enable", - os.environ.get("ENABLE_IMAGE_EDIT", "").lower() == "true", + 'ENABLE_IMAGE_EDIT', + 'images.edit.enable', + os.environ.get('ENABLE_IMAGE_EDIT', '').lower() == 'true', ) IMAGE_EDIT_ENGINE = PersistentConfig( - "IMAGE_EDIT_ENGINE", - "images.edit.engine", - os.getenv("IMAGE_EDIT_ENGINE", "openai"), + 'IMAGE_EDIT_ENGINE', + 'images.edit.engine', + os.getenv('IMAGE_EDIT_ENGINE', 'openai'), ) IMAGE_EDIT_MODEL = PersistentConfig( - "IMAGE_EDIT_MODEL", - "images.edit.model", - os.getenv("IMAGE_EDIT_MODEL", ""), + 'IMAGE_EDIT_MODEL', + 'images.edit.model', + os.getenv('IMAGE_EDIT_MODEL', ''), ) -IMAGE_EDIT_SIZE = PersistentConfig( - "IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "") -) +IMAGE_EDIT_SIZE = PersistentConfig('IMAGE_EDIT_SIZE', 'images.edit.size', os.getenv('IMAGE_EDIT_SIZE', '')) IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig( - "IMAGES_EDIT_OPENAI_API_BASE_URL", - "images.edit.openai.api_base_url", - os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), + 'IMAGES_EDIT_OPENAI_API_BASE_URL', + 'images.edit.openai.api_base_url', + os.getenv('IMAGES_EDIT_OPENAI_API_BASE_URL', OPENAI_API_BASE_URL), ) IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig( - "IMAGES_EDIT_OPENAI_API_VERSION", - "images.edit.openai.api_version", - os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""), + 'IMAGES_EDIT_OPENAI_API_VERSION', + 'images.edit.openai.api_version', + os.getenv('IMAGES_EDIT_OPENAI_API_VERSION', ''), ) IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig( - "IMAGES_EDIT_OPENAI_API_KEY", - "images.edit.openai.api_key", - os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY), + 'IMAGES_EDIT_OPENAI_API_KEY', + 'images.edit.openai.api_key', + os.getenv('IMAGES_EDIT_OPENAI_API_KEY', OPENAI_API_KEY), ) IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig( - "IMAGES_EDIT_GEMINI_API_BASE_URL", - "images.edit.gemini.api_base_url", - os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), + 'IMAGES_EDIT_GEMINI_API_BASE_URL', + 'images.edit.gemini.api_base_url', + os.getenv('IMAGES_EDIT_GEMINI_API_BASE_URL', GEMINI_API_BASE_URL), ) IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig( - "IMAGES_EDIT_GEMINI_API_KEY", - "images.edit.gemini.api_key", - os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY), + 'IMAGES_EDIT_GEMINI_API_KEY', + 'images.edit.gemini.api_key', + os.getenv('IMAGES_EDIT_GEMINI_API_KEY', GEMINI_API_KEY), ) IMAGES_EDIT_COMFYUI_BASE_URL = PersistentConfig( - "IMAGES_EDIT_COMFYUI_BASE_URL", - "images.edit.comfyui.base_url", - os.getenv("IMAGES_EDIT_COMFYUI_BASE_URL", ""), + 'IMAGES_EDIT_COMFYUI_BASE_URL', + 'images.edit.comfyui.base_url', + os.getenv('IMAGES_EDIT_COMFYUI_BASE_URL', ''), ) IMAGES_EDIT_COMFYUI_API_KEY = PersistentConfig( - "IMAGES_EDIT_COMFYUI_API_KEY", - "images.edit.comfyui.api_key", - os.getenv("IMAGES_EDIT_COMFYUI_API_KEY", ""), + 'IMAGES_EDIT_COMFYUI_API_KEY', + 'images.edit.comfyui.api_key', + os.getenv('IMAGES_EDIT_COMFYUI_API_KEY', ''), ) IMAGES_EDIT_COMFYUI_WORKFLOW = PersistentConfig( - "IMAGES_EDIT_COMFYUI_WORKFLOW", - "images.edit.comfyui.workflow", - os.getenv("IMAGES_EDIT_COMFYUI_WORKFLOW", ""), + 'IMAGES_EDIT_COMFYUI_WORKFLOW', + 'images.edit.comfyui.workflow', + os.getenv('IMAGES_EDIT_COMFYUI_WORKFLOW', ''), ) -images_edit_comfyui_workflow_nodes = os.getenv("IMAGES_EDIT_COMFYUI_WORKFLOW_NODES", "") +images_edit_comfyui_workflow_nodes = os.getenv('IMAGES_EDIT_COMFYUI_WORKFLOW_NODES', '') try: images_edit_comfyui_workflow_nodes = json.loads(images_edit_comfyui_workflow_nodes) except json.JSONDecodeError: images_edit_comfyui_workflow_nodes = [] IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = PersistentConfig( - "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES", - "images.edit.comfyui.nodes", + 'IMAGES_EDIT_COMFYUI_WORKFLOW_NODES', + 'images.edit.comfyui.nodes', images_edit_comfyui_workflow_nodes, ) @@ -4032,193 +3728,184 @@ IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = PersistentConfig( # Transcription WHISPER_MODEL = PersistentConfig( - "WHISPER_MODEL", - "audio.stt.whisper_model", - os.getenv("WHISPER_MODEL", "base"), + 'WHISPER_MODEL', + 'audio.stt.whisper_model', + os.getenv('WHISPER_MODEL', 'base'), ) -WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8") -WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") -WHISPER_MODEL_AUTO_UPDATE = ( - not OFFLINE_MODE - and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" -) +WHISPER_COMPUTE_TYPE = os.getenv('WHISPER_COMPUTE_TYPE', 'int8') +WHISPER_MODEL_DIR = os.getenv('WHISPER_MODEL_DIR', f'{CACHE_DIR}/whisper/models') +WHISPER_MODEL_AUTO_UPDATE = not OFFLINE_MODE and os.environ.get('WHISPER_MODEL_AUTO_UPDATE', '').lower() == 'true' -WHISPER_VAD_FILTER = os.getenv("WHISPER_VAD_FILTER", "False").lower() == "true" +WHISPER_VAD_FILTER = os.getenv('WHISPER_VAD_FILTER', 'False').lower() == 'true' -WHISPER_MULTILINGUAL = os.getenv("WHISPER_MULTILINGUAL", "False").lower() == "true" +WHISPER_MULTILINGUAL = os.getenv('WHISPER_MULTILINGUAL', 'False').lower() == 'true' -WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "").lower() or None +WHISPER_LANGUAGE = os.getenv('WHISPER_LANGUAGE', '').lower() or None # Add Deepgram configuration DEEPGRAM_API_KEY = PersistentConfig( - "DEEPGRAM_API_KEY", - "audio.stt.deepgram.api_key", - os.getenv("DEEPGRAM_API_KEY", ""), + 'DEEPGRAM_API_KEY', + 'audio.stt.deepgram.api_key', + os.getenv('DEEPGRAM_API_KEY', ''), ) # ElevenLabs configuration -ELEVENLABS_API_BASE_URL = os.getenv( - "ELEVENLABS_API_BASE_URL", "https://api.elevenlabs.io" -) +ELEVENLABS_API_BASE_URL = os.getenv('ELEVENLABS_API_BASE_URL', 'https://api.elevenlabs.io') AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig( - "AUDIO_STT_OPENAI_API_BASE_URL", - "audio.stt.openai.api_base_url", - os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), + 'AUDIO_STT_OPENAI_API_BASE_URL', + 'audio.stt.openai.api_base_url', + os.getenv('AUDIO_STT_OPENAI_API_BASE_URL', OPENAI_API_BASE_URL), ) AUDIO_STT_OPENAI_API_KEY = PersistentConfig( - "AUDIO_STT_OPENAI_API_KEY", - "audio.stt.openai.api_key", - os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY), + 'AUDIO_STT_OPENAI_API_KEY', + 'audio.stt.openai.api_key', + os.getenv('AUDIO_STT_OPENAI_API_KEY', OPENAI_API_KEY), ) AUDIO_STT_ENGINE = PersistentConfig( - "AUDIO_STT_ENGINE", - "audio.stt.engine", - os.getenv("AUDIO_STT_ENGINE", ""), + 'AUDIO_STT_ENGINE', + 'audio.stt.engine', + os.getenv('AUDIO_STT_ENGINE', ''), ) AUDIO_STT_MODEL = PersistentConfig( - "AUDIO_STT_MODEL", - "audio.stt.model", - os.getenv("AUDIO_STT_MODEL", ""), + 'AUDIO_STT_MODEL', + 'audio.stt.model', + os.getenv('AUDIO_STT_MODEL', ''), ) AUDIO_STT_SUPPORTED_CONTENT_TYPES = PersistentConfig( - "AUDIO_STT_SUPPORTED_CONTENT_TYPES", - "audio.stt.supported_content_types", + 'AUDIO_STT_SUPPORTED_CONTENT_TYPES', + 'audio.stt.supported_content_types', [ content_type.strip() - for content_type in os.environ.get( - "AUDIO_STT_SUPPORTED_CONTENT_TYPES", "" - ).split(",") + for content_type in os.environ.get('AUDIO_STT_SUPPORTED_CONTENT_TYPES', '').split(',') if content_type.strip() ], ) AUDIO_STT_AZURE_API_KEY = PersistentConfig( - "AUDIO_STT_AZURE_API_KEY", - "audio.stt.azure.api_key", - os.getenv("AUDIO_STT_AZURE_API_KEY", ""), + 'AUDIO_STT_AZURE_API_KEY', + 'audio.stt.azure.api_key', + os.getenv('AUDIO_STT_AZURE_API_KEY', ''), ) AUDIO_STT_AZURE_REGION = PersistentConfig( - "AUDIO_STT_AZURE_REGION", - "audio.stt.azure.region", - os.getenv("AUDIO_STT_AZURE_REGION", ""), + 'AUDIO_STT_AZURE_REGION', + 'audio.stt.azure.region', + os.getenv('AUDIO_STT_AZURE_REGION', ''), ) AUDIO_STT_AZURE_LOCALES = PersistentConfig( - "AUDIO_STT_AZURE_LOCALES", - "audio.stt.azure.locales", - os.getenv("AUDIO_STT_AZURE_LOCALES", ""), + 'AUDIO_STT_AZURE_LOCALES', + 'audio.stt.azure.locales', + os.getenv('AUDIO_STT_AZURE_LOCALES', ''), ) AUDIO_STT_AZURE_BASE_URL = PersistentConfig( - "AUDIO_STT_AZURE_BASE_URL", - "audio.stt.azure.base_url", - os.getenv("AUDIO_STT_AZURE_BASE_URL", ""), + 'AUDIO_STT_AZURE_BASE_URL', + 'audio.stt.azure.base_url', + os.getenv('AUDIO_STT_AZURE_BASE_URL', ''), ) AUDIO_STT_AZURE_MAX_SPEAKERS = PersistentConfig( - "AUDIO_STT_AZURE_MAX_SPEAKERS", - "audio.stt.azure.max_speakers", - os.getenv("AUDIO_STT_AZURE_MAX_SPEAKERS", ""), + 'AUDIO_STT_AZURE_MAX_SPEAKERS', + 'audio.stt.azure.max_speakers', + os.getenv('AUDIO_STT_AZURE_MAX_SPEAKERS', ''), ) AUDIO_STT_MISTRAL_API_KEY = PersistentConfig( - "AUDIO_STT_MISTRAL_API_KEY", - "audio.stt.mistral.api_key", - os.getenv("AUDIO_STT_MISTRAL_API_KEY", ""), + 'AUDIO_STT_MISTRAL_API_KEY', + 'audio.stt.mistral.api_key', + os.getenv('AUDIO_STT_MISTRAL_API_KEY', ''), ) AUDIO_STT_MISTRAL_API_BASE_URL = PersistentConfig( - "AUDIO_STT_MISTRAL_API_BASE_URL", - "audio.stt.mistral.api_base_url", - os.getenv("AUDIO_STT_MISTRAL_API_BASE_URL", "https://api.mistral.ai/v1"), + 'AUDIO_STT_MISTRAL_API_BASE_URL', + 'audio.stt.mistral.api_base_url', + os.getenv('AUDIO_STT_MISTRAL_API_BASE_URL', 'https://api.mistral.ai/v1'), ) AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = PersistentConfig( - "AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS", - "audio.stt.mistral.use_chat_completions", - os.getenv("AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS", "false").lower() == "true", + 'AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS', + 'audio.stt.mistral.use_chat_completions', + os.getenv('AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS', 'false').lower() == 'true', ) AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig( - "AUDIO_TTS_OPENAI_API_BASE_URL", - "audio.tts.openai.api_base_url", - os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), + 'AUDIO_TTS_OPENAI_API_BASE_URL', + 'audio.tts.openai.api_base_url', + os.getenv('AUDIO_TTS_OPENAI_API_BASE_URL', OPENAI_API_BASE_URL), ) AUDIO_TTS_OPENAI_API_KEY = PersistentConfig( - "AUDIO_TTS_OPENAI_API_KEY", - "audio.tts.openai.api_key", - os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY), + 'AUDIO_TTS_OPENAI_API_KEY', + 'audio.tts.openai.api_key', + os.getenv('AUDIO_TTS_OPENAI_API_KEY', OPENAI_API_KEY), ) -audio_tts_openai_params = os.getenv("AUDIO_TTS_OPENAI_PARAMS", "") +audio_tts_openai_params = os.getenv('AUDIO_TTS_OPENAI_PARAMS', '') try: audio_tts_openai_params = json.loads(audio_tts_openai_params) except json.JSONDecodeError: audio_tts_openai_params = {} AUDIO_TTS_OPENAI_PARAMS = PersistentConfig( - "AUDIO_TTS_OPENAI_PARAMS", - "audio.tts.openai.params", + 'AUDIO_TTS_OPENAI_PARAMS', + 'audio.tts.openai.params', audio_tts_openai_params, ) AUDIO_TTS_API_KEY = PersistentConfig( - "AUDIO_TTS_API_KEY", - "audio.tts.api_key", - os.getenv("AUDIO_TTS_API_KEY", ""), + 'AUDIO_TTS_API_KEY', + 'audio.tts.api_key', + os.getenv('AUDIO_TTS_API_KEY', ''), ) AUDIO_TTS_ENGINE = PersistentConfig( - "AUDIO_TTS_ENGINE", - "audio.tts.engine", - os.getenv("AUDIO_TTS_ENGINE", ""), + 'AUDIO_TTS_ENGINE', + 'audio.tts.engine', + os.getenv('AUDIO_TTS_ENGINE', ''), ) AUDIO_TTS_MODEL = PersistentConfig( - "AUDIO_TTS_MODEL", - "audio.tts.model", - os.getenv("AUDIO_TTS_MODEL", "tts-1"), # OpenAI default model + 'AUDIO_TTS_MODEL', + 'audio.tts.model', + os.getenv('AUDIO_TTS_MODEL', 'tts-1'), # OpenAI default model ) AUDIO_TTS_VOICE = PersistentConfig( - "AUDIO_TTS_VOICE", - "audio.tts.voice", - os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice + 'AUDIO_TTS_VOICE', + 'audio.tts.voice', + os.getenv('AUDIO_TTS_VOICE', 'alloy'), # OpenAI default voice ) AUDIO_TTS_SPLIT_ON = PersistentConfig( - "AUDIO_TTS_SPLIT_ON", - "audio.tts.split_on", - os.getenv("AUDIO_TTS_SPLIT_ON", "punctuation"), + 'AUDIO_TTS_SPLIT_ON', + 'audio.tts.split_on', + os.getenv('AUDIO_TTS_SPLIT_ON', 'punctuation'), ) AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig( - "AUDIO_TTS_AZURE_SPEECH_REGION", - "audio.tts.azure.speech_region", - os.getenv("AUDIO_TTS_AZURE_SPEECH_REGION", ""), + 'AUDIO_TTS_AZURE_SPEECH_REGION', + 'audio.tts.azure.speech_region', + os.getenv('AUDIO_TTS_AZURE_SPEECH_REGION', ''), ) AUDIO_TTS_AZURE_SPEECH_BASE_URL = PersistentConfig( - "AUDIO_TTS_AZURE_SPEECH_BASE_URL", - "audio.tts.azure.speech_base_url", - os.getenv("AUDIO_TTS_AZURE_SPEECH_BASE_URL", ""), + 'AUDIO_TTS_AZURE_SPEECH_BASE_URL', + 'audio.tts.azure.speech_base_url', + os.getenv('AUDIO_TTS_AZURE_SPEECH_BASE_URL', ''), ) AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig( - "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", - "audio.tts.azure.speech_output_format", - os.getenv( - "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3" - ), + 'AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT', + 'audio.tts.azure.speech_output_format', + os.getenv('AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT', 'audio-24khz-160kbitrate-mono-mp3'), ) @@ -4227,98 +3914,92 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig( #################################### ENABLE_LDAP = PersistentConfig( - "ENABLE_LDAP", - "ldap.enable", - os.environ.get("ENABLE_LDAP", "false").lower() == "true", + 'ENABLE_LDAP', + 'ldap.enable', + os.environ.get('ENABLE_LDAP', 'false').lower() == 'true', ) LDAP_SERVER_LABEL = PersistentConfig( - "LDAP_SERVER_LABEL", - "ldap.server.label", - os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"), + 'LDAP_SERVER_LABEL', + 'ldap.server.label', + os.environ.get('LDAP_SERVER_LABEL', 'LDAP Server'), ) LDAP_SERVER_HOST = PersistentConfig( - "LDAP_SERVER_HOST", - "ldap.server.host", - os.environ.get("LDAP_SERVER_HOST", "localhost"), + 'LDAP_SERVER_HOST', + 'ldap.server.host', + os.environ.get('LDAP_SERVER_HOST', 'localhost'), ) LDAP_SERVER_PORT = PersistentConfig( - "LDAP_SERVER_PORT", - "ldap.server.port", - int(os.environ.get("LDAP_SERVER_PORT", "389")), + 'LDAP_SERVER_PORT', + 'ldap.server.port', + int(os.environ.get('LDAP_SERVER_PORT', '389')), ) LDAP_ATTRIBUTE_FOR_MAIL = PersistentConfig( - "LDAP_ATTRIBUTE_FOR_MAIL", - "ldap.server.attribute_for_mail", - os.environ.get("LDAP_ATTRIBUTE_FOR_MAIL", "mail"), + 'LDAP_ATTRIBUTE_FOR_MAIL', + 'ldap.server.attribute_for_mail', + os.environ.get('LDAP_ATTRIBUTE_FOR_MAIL', 'mail'), ) LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig( - "LDAP_ATTRIBUTE_FOR_USERNAME", - "ldap.server.attribute_for_username", - os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"), + 'LDAP_ATTRIBUTE_FOR_USERNAME', + 'ldap.server.attribute_for_username', + os.environ.get('LDAP_ATTRIBUTE_FOR_USERNAME', 'uid'), ) -LDAP_APP_DN = PersistentConfig( - "LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "") -) +LDAP_APP_DN = PersistentConfig('LDAP_APP_DN', 'ldap.server.app_dn', os.environ.get('LDAP_APP_DN', '')) LDAP_APP_PASSWORD = PersistentConfig( - "LDAP_APP_PASSWORD", - "ldap.server.app_password", - os.environ.get("LDAP_APP_PASSWORD", ""), + 'LDAP_APP_PASSWORD', + 'ldap.server.app_password', + os.environ.get('LDAP_APP_PASSWORD', ''), ) -LDAP_SEARCH_BASE = PersistentConfig( - "LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "") -) +LDAP_SEARCH_BASE = PersistentConfig('LDAP_SEARCH_BASE', 'ldap.server.users_dn', os.environ.get('LDAP_SEARCH_BASE', '')) LDAP_SEARCH_FILTERS = PersistentConfig( - "LDAP_SEARCH_FILTER", - "ldap.server.search_filter", - os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")), + 'LDAP_SEARCH_FILTER', + 'ldap.server.search_filter', + os.environ.get('LDAP_SEARCH_FILTER', os.environ.get('LDAP_SEARCH_FILTERS', '')), ) LDAP_USE_TLS = PersistentConfig( - "LDAP_USE_TLS", - "ldap.server.use_tls", - os.environ.get("LDAP_USE_TLS", "True").lower() == "true", + 'LDAP_USE_TLS', + 'ldap.server.use_tls', + os.environ.get('LDAP_USE_TLS', 'True').lower() == 'true', ) LDAP_CA_CERT_FILE = PersistentConfig( - "LDAP_CA_CERT_FILE", - "ldap.server.ca_cert_file", - os.environ.get("LDAP_CA_CERT_FILE", ""), + 'LDAP_CA_CERT_FILE', + 'ldap.server.ca_cert_file', + os.environ.get('LDAP_CA_CERT_FILE', ''), ) LDAP_VALIDATE_CERT = PersistentConfig( - "LDAP_VALIDATE_CERT", - "ldap.server.validate_cert", - os.environ.get("LDAP_VALIDATE_CERT", "True").lower() == "true", + 'LDAP_VALIDATE_CERT', + 'ldap.server.validate_cert', + os.environ.get('LDAP_VALIDATE_CERT', 'True').lower() == 'true', ) -LDAP_CIPHERS = PersistentConfig( - "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL") -) +LDAP_CIPHERS = PersistentConfig('LDAP_CIPHERS', 'ldap.server.ciphers', os.environ.get('LDAP_CIPHERS', 'ALL')) # For LDAP Group Management ENABLE_LDAP_GROUP_MANAGEMENT = PersistentConfig( - "ENABLE_LDAP_GROUP_MANAGEMENT", - "ldap.group.enable_management", - os.environ.get("ENABLE_LDAP_GROUP_MANAGEMENT", "False").lower() == "true", + 'ENABLE_LDAP_GROUP_MANAGEMENT', + 'ldap.group.enable_management', + os.environ.get('ENABLE_LDAP_GROUP_MANAGEMENT', 'False').lower() == 'true', ) ENABLE_LDAP_GROUP_CREATION = PersistentConfig( - "ENABLE_LDAP_GROUP_CREATION", - "ldap.group.enable_creation", - os.environ.get("ENABLE_LDAP_GROUP_CREATION", "False").lower() == "true", + 'ENABLE_LDAP_GROUP_CREATION', + 'ldap.group.enable_creation', + os.environ.get('ENABLE_LDAP_GROUP_CREATION', 'False').lower() == 'true', ) LDAP_ATTRIBUTE_FOR_GROUPS = PersistentConfig( - "LDAP_ATTRIBUTE_FOR_GROUPS", - "ldap.server.attribute_for_groups", - os.environ.get("LDAP_ATTRIBUTE_FOR_GROUPS", "memberOf"), + 'LDAP_ATTRIBUTE_FOR_GROUPS', + 'ldap.server.attribute_for_groups', + os.environ.get('LDAP_ATTRIBUTE_FOR_GROUPS', 'memberOf'), ) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 4d39d16cdb..ec1d0c6047 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -2,125 +2,107 @@ from enum import Enum class MESSAGES(str, Enum): - DEFAULT = lambda msg="": f"{msg if msg else ''}" - MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully." - MODEL_DELETED = ( - lambda model="": f"The model '{model}' has been deleted successfully." - ) + DEFAULT = lambda msg='': f'{msg if msg else ""}' + MODEL_ADDED = lambda model='': f"The model '{model}' has been added successfully." + MODEL_DELETED = lambda model='': f"The model '{model}' has been deleted successfully." class WEBHOOK_MESSAGES(str, Enum): - DEFAULT = lambda msg="": f"{msg if msg else ''}" - USER_SIGNUP = lambda username="": ( - f"New user signed up: {username}" if username else "New user signed up" - ) + DEFAULT = lambda msg='': f'{msg if msg else ""}' + USER_SIGNUP = lambda username='': (f'New user signed up: {username}' if username else 'New user signed up') class ERROR_MESSAGES(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = ( - lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}' + DEFAULT = lambda err='': f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}' + ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.' + CREATE_USER_ERROR = 'Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance.' + DELETE_USER_ERROR = 'Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot.' + EMAIL_MISMATCH = 'Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again.' + EMAIL_TAKEN = 'Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew.' + USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.' + PASSWORD_TOO_LONG = ( + 'Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long.' ) - ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." - CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." - DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." - EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again." - EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." - USERNAME_TAKEN = ( - "Uh-oh! This username is already registered. Please choose another username." - ) - PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long." - COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." - FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.' + FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.' - ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." - MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." - NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." - MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long." + ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.' + MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.' + NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.' + MODEL_ID_TOO_LONG = 'The model id is too long. Please make sure your model id is less than 256 characters long.' - INVALID_TOKEN = ( - "Your session has expired or the token is invalid. Please sign in again." - ) - INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." + INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.' + INVALID_CRED = 'The email or password provided is incorrect. Please check for typos and try logging in again.' INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." - INCORRECT_PASSWORD = ( - "The password provided is incorrect. Please check for typos and try again." + INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.' + INVALID_TRUSTED_HEADER = ( + 'Your provider has not provided a trusted header. Please contact your administrator for assistance.' ) - INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance." EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation." - UNAUTHORIZED = "401 Unauthorized" - ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." - ACTION_PROHIBITED = ( - "The requested action has been restricted as a security measure." + UNAUTHORIZED = '401 Unauthorized' + ACCESS_PROHIBITED = ( + 'You do not have permission to access this resource. Please contact your administrator for assistance.' ) + ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.' - FILE_NOT_SENT = "FILE_NOT_SENT" + FILE_NOT_SENT = 'FILE_NOT_SENT' FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again." NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." - API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment." + API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.' - MALICIOUS = "Unusual activities detected, please try again in a few minutes." + MALICIOUS = 'Unusual activities detected, please try again in a few minutes.' - PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." - INCORRECT_FORMAT = ( - lambda err="": f"Invalid format. Please use the correct format{err}" - ) - RATE_LIMIT_EXCEEDED = "API rate limit exceeded" + PANDOC_NOT_INSTALLED = 'Pandoc is not installed on the server. Please contact your administrator for assistance.' + INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}' + RATE_LIMIT_EXCEEDED = 'API rate limit exceeded' - MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" - OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" - OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" - CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." - API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment." + MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found" + OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found' + OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama' + CREATE_API_KEY_ERROR = 'Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance.' + API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.' - EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." + EMPTY_CONTENT = 'The content provided is empty. Please ensure that there is text or data present before proceeding.' - DB_NOT_SQLITE = "This feature is only available when running with SQLite databases." + DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.' - INVALID_URL = ( - "Oops! The URL you provided is invalid. Please double-check and try again." - ) + INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.' - WEB_SEARCH_ERROR = ( - lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}" - ) + WEB_SEARCH_ERROR = lambda err='': f'{err if err else "Oops! Something went wrong while searching the web."}' - OLLAMA_API_DISABLED = ( - "The Ollama API is disabled. Please enable it to use this feature." - ) + OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.' FILE_TOO_LARGE = ( - lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}." + lambda size='': f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}." ) - DUPLICATE_CONTENT = ( - "Duplicate content detected. Please provide unique content to proceed." + DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.' + FILE_NOT_PROCESSED = ( + 'Extracted content is not available for this file. Please ensure that the file is processed before proceeding.' ) - FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding." - INVALID_PASSWORD = lambda err="": ( - err if err else "The password does not meet the required validation criteria." - ) + INVALID_PASSWORD = lambda err='': (err if err else 'The password does not meet the required validation criteria.') class TASKS(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = lambda task="": f"{task if task else 'generation'}" - TITLE_GENERATION = "title_generation" - FOLLOW_UP_GENERATION = "follow_up_generation" - TAGS_GENERATION = "tags_generation" - EMOJI_GENERATION = "emoji_generation" - QUERY_GENERATION = "query_generation" - IMAGE_PROMPT_GENERATION = "image_prompt_generation" - AUTOCOMPLETE_GENERATION = "autocomplete_generation" - FUNCTION_CALLING = "function_calling" - MOA_RESPONSE_GENERATION = "moa_response_generation" + DEFAULT = lambda task='': f'{task if task else "generation"}' + TITLE_GENERATION = 'title_generation' + FOLLOW_UP_GENERATION = 'follow_up_generation' + TAGS_GENERATION = 'tags_generation' + EMOJI_GENERATION = 'emoji_generation' + QUERY_GENERATION = 'query_generation' + IMAGE_PROMPT_GENERATION = 'image_prompt_generation' + AUTOCOMPLETE_GENERATION = 'autocomplete_generation' + FUNCTION_CALLING = 'function_calling' + MOA_RESPONSE_GENERATION = 'moa_response_generation' diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index f05f16591e..043fa3c6df 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -37,37 +37,34 @@ BASE_DIR = BACKEND_DIR.parent try: from dotenv import find_dotenv, load_dotenv - load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) + load_dotenv(find_dotenv(str(BASE_DIR / '.env'))) except ImportError: - print("dotenv not installed, skipping...") + print('dotenv not installed, skipping...') -DOCKER = os.environ.get("DOCKER", "False").lower() == "true" +DOCKER = os.environ.get('DOCKER', 'False').lower() == 'true' # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") +USE_CUDA = os.environ.get('USE_CUDA_DOCKER', 'false') -if USE_CUDA.lower() == "true": +if USE_CUDA.lower() == 'true': try: import torch - assert torch.cuda.is_available(), "CUDA not available" - DEVICE_TYPE = "cuda" + assert torch.cuda.is_available(), 'CUDA not available' + DEVICE_TYPE = 'cuda' except Exception as e: - cuda_error = ( - "Error when testing CUDA but USE_CUDA_DOCKER is true. " - f"Resetting USE_CUDA_DOCKER to false: {e}" - ) - os.environ["USE_CUDA_DOCKER"] = "false" - USE_CUDA = "false" - DEVICE_TYPE = "cpu" + cuda_error = f'Error when testing CUDA but USE_CUDA_DOCKER is true. Resetting USE_CUDA_DOCKER to false: {e}' + os.environ['USE_CUDA_DOCKER'] = 'false' + USE_CUDA = 'false' + DEVICE_TYPE = 'cpu' else: - DEVICE_TYPE = "cpu" + DEVICE_TYPE = 'cpu' try: import torch if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - DEVICE_TYPE = "mps" + DEVICE_TYPE = 'mps' except Exception: pass @@ -76,11 +73,11 @@ except Exception: #################################### _LEVEL_MAP = { - "DEBUG": "debug", - "INFO": "info", - "WARNING": "warn", - "ERROR": "error", - "CRITICAL": "fatal", + 'DEBUG': 'debug', + 'INFO': 'info', + 'WARNING': 'warn', + 'ERROR': 'error', + 'CRITICAL': 'fatal', } @@ -89,132 +86,128 @@ class JSONFormatter(logging.Formatter): def format(self, record: logging.LogRecord) -> str: log_entry: dict[str, Any] = { - "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat( - timespec="milliseconds" - ), - "level": _LEVEL_MAP.get(record.levelname, record.levelname.lower()), - "msg": record.getMessage(), - "caller": record.name, + 'ts': datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(timespec='milliseconds'), + 'level': _LEVEL_MAP.get(record.levelname, record.levelname.lower()), + 'msg': record.getMessage(), + 'caller': record.name, } if record.exc_info and record.exc_info[0] is not None: - log_entry["error"] = "".join( - traceback.format_exception(*record.exc_info) - ).rstrip() + log_entry['error'] = ''.join(traceback.format_exception(*record.exc_info)).rstrip() elif record.exc_text: - log_entry["error"] = record.exc_text + log_entry['error'] = record.exc_text if record.stack_info: - log_entry["stacktrace"] = record.stack_info + log_entry['stacktrace'] = record.stack_info return json.dumps(log_entry, ensure_ascii=False, default=str) -LOG_FORMAT = os.environ.get("LOG_FORMAT", "").lower() +LOG_FORMAT = os.environ.get('LOG_FORMAT', '').lower() -GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() +GLOBAL_LOG_LEVEL = os.environ.get('GLOBAL_LOG_LEVEL', '').upper() if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping(): - if LOG_FORMAT == "json": + if LOG_FORMAT == 'json': _handler = logging.StreamHandler(sys.stdout) _handler.setFormatter(JSONFormatter()) logging.basicConfig(handlers=[_handler], level=GLOBAL_LOG_LEVEL, force=True) else: logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True) else: - GLOBAL_LOG_LEVEL = "INFO" + GLOBAL_LOG_LEVEL = 'INFO' log = logging.getLogger(__name__) -log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") +log.info(f'GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}') -if "cuda_error" in locals(): +if 'cuda_error' in locals(): log.exception(cuda_error) del cuda_error SRC_LOG_LEVELS = {} # Legacy variable, do not remove -WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") -if WEBUI_NAME != "Open WebUI": - WEBUI_NAME += " (Open WebUI)" +WEBUI_NAME = os.environ.get('WEBUI_NAME', 'Open WebUI') +if WEBUI_NAME != 'Open WebUI': + WEBUI_NAME += ' (Open WebUI)' -WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" +WEBUI_FAVICON_URL = 'https://openwebui.com/favicon.png' -TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "") +TRUSTED_SIGNATURE_KEY = os.environ.get('TRUSTED_SIGNATURE_KEY', '') #################################### # ENV (dev,test,prod) #################################### -ENV = os.environ.get("ENV", "dev") +ENV = os.environ.get('ENV', 'dev') -FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true" +FROM_INIT_PY = os.environ.get('FROM_INIT_PY', 'False').lower() == 'true' if FROM_INIT_PY: - PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} + PACKAGE_DATA = {'version': importlib.metadata.version('open-webui')} else: try: - PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) + PACKAGE_DATA = json.loads((BASE_DIR / 'package.json').read_text()) except Exception: - PACKAGE_DATA = {"version": "0.0.0"} + PACKAGE_DATA = {'version': '0.0.0'} -VERSION = PACKAGE_DATA["version"] +VERSION = PACKAGE_DATA['version'] -DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "") -INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4())) +DEPLOYMENT_ID = os.environ.get('DEPLOYMENT_ID', '') +INSTANCE_ID = os.environ.get('INSTANCE_ID', str(uuid4())) -ENABLE_DB_MIGRATIONS = os.environ.get("ENABLE_DB_MIGRATIONS", "True").lower() == "true" +ENABLE_DB_MIGRATIONS = os.environ.get('ENABLE_DB_MIGRATIONS', 'True').lower() == 'true' # Function to parse each section def parse_section(section): items = [] - for li in section.find_all("li"): + for li in section.find_all('li'): # Extract raw HTML string raw_html = str(li) # Extract text without HTML tags - text = li.get_text(separator=" ", strip=True) + text = li.get_text(separator=' ', strip=True) # Split into title and content - parts = text.split(": ", 1) - title = parts[0].strip() if len(parts) > 1 else "" + parts = text.split(': ', 1) + title = parts[0].strip() if len(parts) > 1 else '' content = parts[1].strip() if len(parts) > 1 else text - items.append({"title": title, "content": content, "raw": raw_html}) + items.append({'title': title, 'content': content, 'raw': raw_html}) return items try: - changelog_path = BASE_DIR / "CHANGELOG.md" - with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: + changelog_path = BASE_DIR / 'CHANGELOG.md' + with open(str(changelog_path.absolute()), 'r', encoding='utf8') as file: changelog_content = file.read() except Exception: - changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() + changelog_content = (pkgutil.get_data('open_webui', 'CHANGELOG.md') or b'').decode() # Convert markdown content to HTML html_content = markdown.markdown(changelog_content) # Parse the HTML content -soup = BeautifulSoup(html_content, "html.parser") +soup = BeautifulSoup(html_content, 'html.parser') # Initialize JSON structure changelog_json = {} # Iterate over each version -for version in soup.find_all("h2"): - version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets - date = version.get_text().strip().split(" - ")[1] +for version in soup.find_all('h2'): + version_number = version.get_text().strip().split(' - ')[0][1:-1] # Remove brackets + date = version.get_text().strip().split(' - ')[1] - version_data = {"date": date} + version_data = {'date': date} # Find the next sibling that is a h3 tag (section title) current = version.find_next_sibling() - while current and current.name != "h2": - if current.name == "h3": + while current and current.name != 'h2': + if current.name == 'h3': section_title = current.get_text().lower() # e.g., "added", "fixed" - section_items = parse_section(current.find_next_sibling("ul")) + section_items = parse_section(current.find_next_sibling('ul')) version_data[section_title] = section_items # Move to the next element @@ -228,65 +221,51 @@ CHANGELOG = changelog_json # SAFE_MODE #################################### -SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" +SAFE_MODE = os.environ.get('SAFE_MODE', 'false').lower() == 'true' #################################### # ENABLE_FORWARD_USER_INFO_HEADERS #################################### -ENABLE_FORWARD_USER_INFO_HEADERS = ( - os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" -) +ENABLE_FORWARD_USER_INFO_HEADERS = os.environ.get('ENABLE_FORWARD_USER_INFO_HEADERS', 'False').lower() == 'true' # Header names for user info forwarding (customizable via environment variables) -FORWARD_USER_INFO_HEADER_USER_NAME = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_NAME", "X-OpenWebUI-User-Name" -) -FORWARD_USER_INFO_HEADER_USER_ID = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_ID", "X-OpenWebUI-User-Id" -) -FORWARD_USER_INFO_HEADER_USER_EMAIL = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_EMAIL", "X-OpenWebUI-User-Email" -) -FORWARD_USER_INFO_HEADER_USER_ROLE = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_ROLE", "X-OpenWebUI-User-Role" -) +FORWARD_USER_INFO_HEADER_USER_NAME = os.environ.get('FORWARD_USER_INFO_HEADER_USER_NAME', 'X-OpenWebUI-User-Name') +FORWARD_USER_INFO_HEADER_USER_ID = os.environ.get('FORWARD_USER_INFO_HEADER_USER_ID', 'X-OpenWebUI-User-Id') +FORWARD_USER_INFO_HEADER_USER_EMAIL = os.environ.get('FORWARD_USER_INFO_HEADER_USER_EMAIL', 'X-OpenWebUI-User-Email') +FORWARD_USER_INFO_HEADER_USER_ROLE = os.environ.get('FORWARD_USER_INFO_HEADER_USER_ROLE', 'X-OpenWebUI-User-Role') # Header name for chat ID forwarding (customizable via environment variable) FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.environ.get( - "FORWARD_SESSION_INFO_HEADER_MESSAGE_ID", "X-OpenWebUI-Message-Id" -) -FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.environ.get( - "FORWARD_SESSION_INFO_HEADER_CHAT_ID", "X-OpenWebUI-Chat-Id" + 'FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id' ) +FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.environ.get('FORWARD_SESSION_INFO_HEADER_CHAT_ID', 'X-OpenWebUI-Chat-Id') # Experimental feature, may be removed in future -ENABLE_STAR_SESSIONS_MIDDLEWARE = ( - os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true" -) +ENABLE_STAR_SESSIONS_MIDDLEWARE = os.environ.get('ENABLE_STAR_SESSIONS_MIDDLEWARE', 'False').lower() == 'true' -ENABLE_EASTER_EGGS = os.environ.get("ENABLE_EASTER_EGGS", "True").lower() == "true" +ENABLE_EASTER_EGGS = os.environ.get('ENABLE_EASTER_EGGS', 'True').lower() == 'true' #################################### # WEBUI_BUILD_HASH #################################### -WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") +WEBUI_BUILD_HASH = os.environ.get('WEBUI_BUILD_HASH', 'dev-build') #################################### # DATA/FRONTEND BUILD DIR #################################### -DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() +DATA_DIR = Path(os.getenv('DATA_DIR', BACKEND_DIR / 'data')).resolve() if FROM_INIT_PY: - NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve() + NEW_DATA_DIR = Path(os.getenv('DATA_DIR', OPEN_WEBUI_DIR / 'data')).resolve() NEW_DATA_DIR.mkdir(parents=True, exist_ok=True) # Check if the data directory exists in the package directory if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR: - log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") + log.info(f'Moving {DATA_DIR} to {NEW_DATA_DIR}') for item in DATA_DIR.iterdir(): dest = NEW_DATA_DIR / item.name if item.is_dir(): @@ -295,69 +274,69 @@ if FROM_INIT_PY: shutil.copy2(item, dest) # Zip the data directory - shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR) + shutil.make_archive(DATA_DIR.parent / 'open_webui_data', 'zip', DATA_DIR) # Remove the old data directory shutil.rmtree(DATA_DIR) - DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) + DATA_DIR = Path(os.getenv('DATA_DIR', OPEN_WEBUI_DIR / 'data')) -STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) +STATIC_DIR = Path(os.getenv('STATIC_DIR', OPEN_WEBUI_DIR / 'static')) -FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) +FONTS_DIR = Path(os.getenv('FONTS_DIR', OPEN_WEBUI_DIR / 'static' / 'fonts')) -FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() +FRONTEND_BUILD_DIR = Path(os.getenv('FRONTEND_BUILD_DIR', BASE_DIR / 'build')).resolve() if FROM_INIT_PY: - FRONTEND_BUILD_DIR = Path( - os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend") - ).resolve() + FRONTEND_BUILD_DIR = Path(os.getenv('FRONTEND_BUILD_DIR', OPEN_WEBUI_DIR / 'frontend')).resolve() #################################### # Database #################################### # Check if the file exists -if os.path.exists(f"{DATA_DIR}/ollama.db"): +if os.path.exists(f'{DATA_DIR}/ollama.db'): # Rename the file - os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") - log.info("Database migrated from Ollama-WebUI successfully.") + os.rename(f'{DATA_DIR}/ollama.db', f'{DATA_DIR}/webui.db') + log.info('Database migrated from Ollama-WebUI successfully.') else: pass -DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") +DATABASE_URL = os.environ.get('DATABASE_URL', f'sqlite:///{DATA_DIR}/webui.db') -DATABASE_TYPE = os.environ.get("DATABASE_TYPE") -DATABASE_USER = os.environ.get("DATABASE_USER") -DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD") +DATABASE_TYPE = os.environ.get('DATABASE_TYPE') +DATABASE_USER = os.environ.get('DATABASE_USER') +DATABASE_PASSWORD = os.environ.get('DATABASE_PASSWORD') -DATABASE_CRED = "" +DATABASE_CRED = '' if DATABASE_USER: - DATABASE_CRED += f"{DATABASE_USER}" + DATABASE_CRED += f'{DATABASE_USER}' if DATABASE_PASSWORD: - DATABASE_CRED += f":{DATABASE_PASSWORD}" + DATABASE_CRED += f':{DATABASE_PASSWORD}' DB_VARS = { - "db_type": DATABASE_TYPE, - "db_cred": DATABASE_CRED, - "db_host": os.environ.get("DATABASE_HOST"), - "db_port": os.environ.get("DATABASE_PORT"), - "db_name": os.environ.get("DATABASE_NAME"), + 'db_type': DATABASE_TYPE, + 'db_cred': DATABASE_CRED, + 'db_host': os.environ.get('DATABASE_HOST'), + 'db_port': os.environ.get('DATABASE_PORT'), + 'db_name': os.environ.get('DATABASE_NAME'), } if all(DB_VARS.values()): - DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}" -elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"): + DATABASE_URL = ( + f'{DB_VARS["db_type"]}://{DB_VARS["db_cred"]}@{DB_VARS["db_host"]}:{DB_VARS["db_port"]}/{DB_VARS["db_name"]}' + ) +elif DATABASE_TYPE == 'sqlite+sqlcipher' and not os.environ.get('DATABASE_URL'): # Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set - DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db" + DATABASE_URL = f'sqlite+sqlcipher:///{DATA_DIR}/webui.db' # Replace the postgres:// with postgresql:// -if "postgres://" in DATABASE_URL: - DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") +if 'postgres://' in DATABASE_URL: + DATABASE_URL = DATABASE_URL.replace('postgres://', 'postgresql://') -DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) +DATABASE_SCHEMA = os.environ.get('DATABASE_SCHEMA', None) -DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None) +DATABASE_POOL_SIZE = os.environ.get('DATABASE_POOL_SIZE', None) if DATABASE_POOL_SIZE != None: try: @@ -365,9 +344,9 @@ if DATABASE_POOL_SIZE != None: except Exception: DATABASE_POOL_SIZE = None -DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) +DATABASE_POOL_MAX_OVERFLOW = os.environ.get('DATABASE_POOL_MAX_OVERFLOW', 0) -if DATABASE_POOL_MAX_OVERFLOW == "": +if DATABASE_POOL_MAX_OVERFLOW == '': DATABASE_POOL_MAX_OVERFLOW = 0 else: try: @@ -375,9 +354,9 @@ else: except Exception: DATABASE_POOL_MAX_OVERFLOW = 0 -DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30) +DATABASE_POOL_TIMEOUT = os.environ.get('DATABASE_POOL_TIMEOUT', 30) -if DATABASE_POOL_TIMEOUT == "": +if DATABASE_POOL_TIMEOUT == '': DATABASE_POOL_TIMEOUT = 30 else: try: @@ -385,9 +364,9 @@ else: except Exception: DATABASE_POOL_TIMEOUT = 30 -DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600) +DATABASE_POOL_RECYCLE = os.environ.get('DATABASE_POOL_RECYCLE', 3600) -if DATABASE_POOL_RECYCLE == "": +if DATABASE_POOL_RECYCLE == '': DATABASE_POOL_RECYCLE = 3600 else: try: @@ -395,57 +374,43 @@ else: except Exception: DATABASE_POOL_RECYCLE = 3600 -DATABASE_ENABLE_SQLITE_WAL = ( - os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true" -) +DATABASE_ENABLE_SQLITE_WAL = os.environ.get('DATABASE_ENABLE_SQLITE_WAL', 'False').lower() == 'true' -DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get( - "DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None -) +DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get('DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL', None) if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None: try: - DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float( - DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL - ) + DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) except Exception: DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0 # When enabled, get_db_context reuses existing sessions; set to False to always create new sessions -DATABASE_ENABLE_SESSION_SHARING = ( - os.environ.get("DATABASE_ENABLE_SESSION_SHARING", "False").lower() == "true" -) +DATABASE_ENABLE_SESSION_SHARING = os.environ.get('DATABASE_ENABLE_SESSION_SHARING', 'False').lower() == 'true' # Enable public visibility of active user count (when disabled, only admins can see it) -ENABLE_PUBLIC_ACTIVE_USERS_COUNT = ( - os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true" -) +ENABLE_PUBLIC_ACTIVE_USERS_COUNT = os.environ.get('ENABLE_PUBLIC_ACTIVE_USERS_COUNT', 'True').lower() == 'true' -RESET_CONFIG_ON_START = ( - os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" -) +RESET_CONFIG_ON_START = os.environ.get('RESET_CONFIG_ON_START', 'False').lower() == 'true' -ENABLE_REALTIME_CHAT_SAVE = ( - os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true" -) +ENABLE_REALTIME_CHAT_SAVE = os.environ.get('ENABLE_REALTIME_CHAT_SAVE', 'False').lower() == 'true' -ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true" +ENABLE_QUERIES_CACHE = os.environ.get('ENABLE_QUERIES_CACHE', 'False').lower() == 'true' -RAG_SYSTEM_CONTEXT = os.environ.get("RAG_SYSTEM_CONTEXT", "False").lower() == "true" +RAG_SYSTEM_CONTEXT = os.environ.get('RAG_SYSTEM_CONTEXT', 'False').lower() == 'true' #################################### # REDIS #################################### -REDIS_URL = os.environ.get("REDIS_URL", "") -REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true" +REDIS_URL = os.environ.get('REDIS_URL', '') +REDIS_CLUSTER = os.environ.get('REDIS_CLUSTER', 'False').lower() == 'true' -REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui") +REDIS_KEY_PREFIX = os.environ.get('REDIS_KEY_PREFIX', 'open-webui') -REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") -REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") +REDIS_SENTINEL_HOSTS = os.environ.get('REDIS_SENTINEL_HOSTS', '') +REDIS_SENTINEL_PORT = os.environ.get('REDIS_SENTINEL_PORT', '26379') # Maximum number of retries for Redis operations when using Sentinel fail-over -REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2") +REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get('REDIS_SENTINEL_MAX_RETRY_COUNT', '2') try: REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT) if REDIS_SENTINEL_MAX_RETRY_COUNT < 1: @@ -454,15 +419,15 @@ except ValueError: REDIS_SENTINEL_MAX_RETRY_COUNT = 2 -REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get("REDIS_SOCKET_CONNECT_TIMEOUT", "") +REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get('REDIS_SOCKET_CONNECT_TIMEOUT', '') try: REDIS_SOCKET_CONNECT_TIMEOUT = float(REDIS_SOCKET_CONNECT_TIMEOUT) except ValueError: REDIS_SOCKET_CONNECT_TIMEOUT = None -REDIS_RECONNECT_DELAY = os.environ.get("REDIS_RECONNECT_DELAY", "") +REDIS_RECONNECT_DELAY = os.environ.get('REDIS_RECONNECT_DELAY', '') -if REDIS_RECONNECT_DELAY == "": +if REDIS_RECONNECT_DELAY == '': REDIS_RECONNECT_DELAY = None else: try: @@ -477,27 +442,23 @@ else: #################################### # Number of uvicorn worker processes for handling requests -UVICORN_WORKERS = os.environ.get("UVICORN_WORKERS", "1") +UVICORN_WORKERS = os.environ.get('UVICORN_WORKERS', '1') try: UVICORN_WORKERS = int(UVICORN_WORKERS) if UVICORN_WORKERS < 1: UVICORN_WORKERS = 1 except ValueError: UVICORN_WORKERS = 1 - log.info(f"Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}") + log.info(f'Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}') #################################### # WEBUI_AUTH (Required for security) #################################### -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH = os.environ.get('WEBUI_AUTH', 'True').lower() == 'true' -ENABLE_INITIAL_ADMIN_SIGNUP = ( - os.environ.get("ENABLE_INITIAL_ADMIN_SIGNUP", "False").lower() == "true" -) -ENABLE_SIGNUP_PASSWORD_CONFIRMATION = ( - os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true" -) +ENABLE_INITIAL_ADMIN_SIGNUP = os.environ.get('ENABLE_INITIAL_ADMIN_SIGNUP', 'False').lower() == 'true' +ENABLE_SIGNUP_PASSWORD_CONFIRMATION = os.environ.get('ENABLE_SIGNUP_PASSWORD_CONFIRMATION', 'False').lower() == 'true' #################################### # Admin Account Runtime Creation @@ -505,164 +466,131 @@ ENABLE_SIGNUP_PASSWORD_CONFIRMATION = ( # Optional env vars for creating an admin account on startup # Useful for headless/automated deployments -WEBUI_ADMIN_EMAIL = os.environ.get("WEBUI_ADMIN_EMAIL", "") -WEBUI_ADMIN_PASSWORD = os.environ.get("WEBUI_ADMIN_PASSWORD", "") -WEBUI_ADMIN_NAME = os.environ.get("WEBUI_ADMIN_NAME", "Admin") +WEBUI_ADMIN_EMAIL = os.environ.get('WEBUI_ADMIN_EMAIL', '') +WEBUI_ADMIN_PASSWORD = os.environ.get('WEBUI_ADMIN_PASSWORD', '') +WEBUI_ADMIN_NAME = os.environ.get('WEBUI_ADMIN_NAME', 'Admin') -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) -WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) -WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None -) +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_EMAIL_HEADER', None) +WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_NAME_HEADER', None) +WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_GROUPS_HEADER', None) -ENABLE_PASSWORD_VALIDATION = ( - os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true" -) +ENABLE_PASSWORD_VALIDATION = os.environ.get('ENABLE_PASSWORD_VALIDATION', 'False').lower() == 'true' PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get( - "PASSWORD_VALIDATION_REGEX_PATTERN", - r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$", + 'PASSWORD_VALIDATION_REGEX_PATTERN', + r'^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$', ) try: - PASSWORD_VALIDATION_REGEX_PATTERN = rf"{PASSWORD_VALIDATION_REGEX_PATTERN}" + PASSWORD_VALIDATION_REGEX_PATTERN = rf'{PASSWORD_VALIDATION_REGEX_PATTERN}' PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN) except Exception as e: - log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}") - PASSWORD_VALIDATION_REGEX_PATTERN = re.compile( - r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$" - ) + log.error(f'Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}') + PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(r'^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$') -PASSWORD_VALIDATION_HINT = os.environ.get("PASSWORD_VALIDATION_HINT", "") +PASSWORD_VALIDATION_HINT = os.environ.get('PASSWORD_VALIDATION_HINT', '') -BYPASS_MODEL_ACCESS_CONTROL = ( - os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" -) +BYPASS_MODEL_ACCESS_CONTROL = os.environ.get('BYPASS_MODEL_ACCESS_CONTROL', 'False').lower() == 'true' -WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get( - "WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None -) +WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get('WEBUI_AUTH_SIGNOUT_REDIRECT_URL', None) #################################### # WEBUI_SECRET_KEY #################################### WEBUI_SECRET_KEY = os.environ.get( - "WEBUI_SECRET_KEY", - os.environ.get( - "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" - ), # DEPRECATED: remove at next major version + 'WEBUI_SECRET_KEY', + os.environ.get('WEBUI_JWT_SECRET_KEY', 't0p-s3cr3t'), # DEPRECATED: remove at next major version ) -WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax") +WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get('WEBUI_SESSION_COOKIE_SAME_SITE', 'lax') -WEBUI_SESSION_COOKIE_SECURE = ( - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true" -) +WEBUI_SESSION_COOKIE_SECURE = os.environ.get('WEBUI_SESSION_COOKIE_SECURE', 'false').lower() == 'true' -WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get( - "WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE -) +WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get('WEBUI_AUTH_COOKIE_SAME_SITE', WEBUI_SESSION_COOKIE_SAME_SITE) WEBUI_AUTH_COOKIE_SECURE = ( os.environ.get( - "WEBUI_AUTH_COOKIE_SECURE", - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"), + 'WEBUI_AUTH_COOKIE_SECURE', + os.environ.get('WEBUI_SESSION_COOKIE_SECURE', 'false'), ).lower() - == "true" + == 'true' ) -if WEBUI_AUTH and WEBUI_SECRET_KEY == "": +if WEBUI_AUTH and WEBUI_SECRET_KEY == '': raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) -ENABLE_COMPRESSION_MIDDLEWARE = ( - os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" -) +ENABLE_COMPRESSION_MIDDLEWARE = os.environ.get('ENABLE_COMPRESSION_MIDDLEWARE', 'True').lower() == 'true' #################################### # OAUTH Configuration #################################### -ENABLE_OAUTH_EMAIL_FALLBACK = ( - os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true" -) +ENABLE_OAUTH_EMAIL_FALLBACK = os.environ.get('ENABLE_OAUTH_EMAIL_FALLBACK', 'False').lower() == 'true' -ENABLE_OAUTH_ID_TOKEN_COOKIE = ( - os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" -) +ENABLE_OAUTH_ID_TOKEN_COOKIE = os.environ.get('ENABLE_OAUTH_ID_TOKEN_COOKIE', 'True').lower() == 'true' -OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get( - "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY -) +OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get('OAUTH_CLIENT_INFO_ENCRYPTION_KEY', WEBUI_SECRET_KEY) -OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( - "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY -) +OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY', WEBUI_SECRET_KEY) # Maximum number of concurrent OAuth sessions per user per provider # This prevents unbounded session growth while allowing multi-device usage -OAUTH_MAX_SESSIONS_PER_USER = int(os.environ.get("OAUTH_MAX_SESSIONS_PER_USER", "10")) +OAUTH_MAX_SESSIONS_PER_USER = int(os.environ.get('OAUTH_MAX_SESSIONS_PER_USER', '10')) # Token Exchange Configuration # Allows external apps to exchange OAuth tokens for OpenWebUI tokens -ENABLE_OAUTH_TOKEN_EXCHANGE = ( - os.environ.get("ENABLE_OAUTH_TOKEN_EXCHANGE", "False").lower() == "true" -) +ENABLE_OAUTH_TOKEN_EXCHANGE = os.environ.get('ENABLE_OAUTH_TOKEN_EXCHANGE', 'False').lower() == 'true' #################################### # SCIM Configuration #################################### -ENABLE_SCIM = ( - os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower() - == "true" -) -SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "") -SCIM_AUTH_PROVIDER = os.environ.get("SCIM_AUTH_PROVIDER", "") +ENABLE_SCIM = os.environ.get('ENABLE_SCIM', os.environ.get('SCIM_ENABLED', 'False')).lower() == 'true' +SCIM_TOKEN = os.environ.get('SCIM_TOKEN', '') +SCIM_AUTH_PROVIDER = os.environ.get('SCIM_AUTH_PROVIDER', '') if ENABLE_SCIM and not SCIM_AUTH_PROVIDER: log.warning( - "SCIM is enabled but SCIM_AUTH_PROVIDER is not set. " + 'SCIM is enabled but SCIM_AUTH_PROVIDER is not set. ' "Set SCIM_AUTH_PROVIDER to the OAuth provider name (e.g. 'microsoft', 'oidc') " - "to enable externalId storage." + 'to enable externalId storage.' ) #################################### # LICENSE_KEY #################################### -LICENSE_KEY = os.environ.get("LICENSE_KEY", "") +LICENSE_KEY = os.environ.get('LICENSE_KEY', '') LICENSE_BLOB = None -LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data") +LICENSE_BLOB_PATH = os.environ.get('LICENSE_BLOB_PATH', DATA_DIR / 'l.data') if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH): - with open(LICENSE_BLOB_PATH, "rb") as f: + with open(LICENSE_BLOB_PATH, 'rb') as f: LICENSE_BLOB = f.read() -LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "") +LICENSE_PUBLIC_KEY = os.environ.get('LICENSE_PUBLIC_KEY', '') pk = None if LICENSE_PUBLIC_KEY: - pk = serialization.load_pem_public_key(f""" + pk = serialization.load_pem_public_key( + f""" -----BEGIN PUBLIC KEY----- {LICENSE_PUBLIC_KEY} -----END PUBLIC KEY----- -""".encode("utf-8")) +""".encode('utf-8') + ) #################################### # MODELS #################################### -ENABLE_CUSTOM_MODEL_FALLBACK = ( - os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true" -) +ENABLE_CUSTOM_MODEL_FALLBACK = os.environ.get('ENABLE_CUSTOM_MODEL_FALLBACK', 'False').lower() == 'true' -MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1") -if MODELS_CACHE_TTL == "": +MODELS_CACHE_TTL = os.environ.get('MODELS_CACHE_TTL', '1') +if MODELS_CACHE_TTL == '': MODELS_CACHE_TTL = None else: try: @@ -676,30 +604,23 @@ else: #################################### ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = ( - os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower() - == "true" + os.environ.get('ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION', 'False').lower() == 'true' ) -CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get( - "CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1" -) +CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get('CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE', '1') -if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "": +if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == '': CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 else: try: - CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int( - CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE - ) + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE) except Exception: CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 -CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get( - "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30" -) +CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get('CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES', '30') -if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "": +if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == '': CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 else: try: @@ -708,17 +629,13 @@ else: CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 -CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get( - "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", "" -) +CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get('CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE', '') -if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "": +if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == '': CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None else: try: - CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int( - CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE - ) + CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE) except Exception: CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None @@ -727,70 +644,62 @@ else: # WEBSOCKET SUPPORT #################################### -ENABLE_WEBSOCKET_SUPPORT = ( - os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" -) +ENABLE_WEBSOCKET_SUPPORT = os.environ.get('ENABLE_WEBSOCKET_SUPPORT', 'True').lower() == 'true' -WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") +WEBSOCKET_MANAGER = os.environ.get('WEBSOCKET_MANAGER', '') -WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "") +WEBSOCKET_REDIS_OPTIONS = os.environ.get('WEBSOCKET_REDIS_OPTIONS', '') -if WEBSOCKET_REDIS_OPTIONS == "": +if WEBSOCKET_REDIS_OPTIONS == '': if REDIS_SOCKET_CONNECT_TIMEOUT: - WEBSOCKET_REDIS_OPTIONS = { - "socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT - } + WEBSOCKET_REDIS_OPTIONS = {'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT} else: - log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None") + log.debug('No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None') WEBSOCKET_REDIS_OPTIONS = None else: try: WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS) except Exception: - log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None") + log.warning('Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None') WEBSOCKET_REDIS_OPTIONS = None -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) -WEBSOCKET_REDIS_CLUSTER = ( - os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true" -) +WEBSOCKET_REDIS_URL = os.environ.get('WEBSOCKET_REDIS_URL', REDIS_URL) +WEBSOCKET_REDIS_CLUSTER = os.environ.get('WEBSOCKET_REDIS_CLUSTER', str(REDIS_CLUSTER)).lower() == 'true' -websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60") +websocket_redis_lock_timeout = os.environ.get('WEBSOCKET_REDIS_LOCK_TIMEOUT', '60') try: WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout) except ValueError: WEBSOCKET_REDIS_LOCK_TIMEOUT = 60 -WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") -WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") -WEBSOCKET_SERVER_LOGGING = ( - os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true" -) +WEBSOCKET_SENTINEL_HOSTS = os.environ.get('WEBSOCKET_SENTINEL_HOSTS', '') +WEBSOCKET_SENTINEL_PORT = os.environ.get('WEBSOCKET_SENTINEL_PORT', '26379') +WEBSOCKET_SERVER_LOGGING = os.environ.get('WEBSOCKET_SERVER_LOGGING', 'False').lower() == 'true' WEBSOCKET_SERVER_ENGINEIO_LOGGING = ( os.environ.get( - "WEBSOCKET_SERVER_ENGINEIO_LOGGING", - os.environ.get("WEBSOCKET_SERVER_LOGGING", "False"), + 'WEBSOCKET_SERVER_ENGINEIO_LOGGING', + os.environ.get('WEBSOCKET_SERVER_LOGGING', 'False'), ).lower() - == "true" + == 'true' ) -WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20") +WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get('WEBSOCKET_SERVER_PING_TIMEOUT', '20') try: WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT) except ValueError: WEBSOCKET_SERVER_PING_TIMEOUT = 20 -WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25") +WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get('WEBSOCKET_SERVER_PING_INTERVAL', '25') try: WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL) except ValueError: WEBSOCKET_SERVER_PING_INTERVAL = 25 -WEBSOCKET_EVENT_CALLER_TIMEOUT = os.environ.get("WEBSOCKET_EVENT_CALLER_TIMEOUT", "") +WEBSOCKET_EVENT_CALLER_TIMEOUT = os.environ.get('WEBSOCKET_EVENT_CALLER_TIMEOUT', '') -if WEBSOCKET_EVENT_CALLER_TIMEOUT == "": +if WEBSOCKET_EVENT_CALLER_TIMEOUT == '': WEBSOCKET_EVENT_CALLER_TIMEOUT = None else: try: @@ -799,11 +708,11 @@ else: WEBSOCKET_EVENT_CALLER_TIMEOUT = 300 -REQUESTS_VERIFY = os.environ.get("REQUESTS_VERIFY", "True").lower() == "true" +REQUESTS_VERIFY = os.environ.get('REQUESTS_VERIFY', 'True').lower() == 'true' -AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") +AIOHTTP_CLIENT_TIMEOUT = os.environ.get('AIOHTTP_CLIENT_TIMEOUT', '') -if AIOHTTP_CLIENT_TIMEOUT == "": +if AIOHTTP_CLIENT_TIMEOUT == '': AIOHTTP_CLIENT_TIMEOUT = None else: try: @@ -812,16 +721,14 @@ else: AIOHTTP_CLIENT_TIMEOUT = 300 -AIOHTTP_CLIENT_SESSION_SSL = ( - os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true" -) +AIOHTTP_CLIENT_SESSION_SSL = os.environ.get('AIOHTTP_CLIENT_SESSION_SSL', 'True').lower() == 'true' AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST", - os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"), + 'AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST', + os.environ.get('AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST', '10'), ) -if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "": +if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == '': AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None else: try: @@ -830,29 +737,25 @@ else: AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10 -AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10" -) +AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get('AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA', '10') -if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "": +if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == '': AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None else: try: - AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int( - AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA - ) + AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) except Exception: AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10 AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = ( - os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true" + os.environ.get('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true' ) -RAG_EMBEDDING_TIMEOUT = os.environ.get("RAG_EMBEDDING_TIMEOUT", "") +RAG_EMBEDDING_TIMEOUT = os.environ.get('RAG_EMBEDDING_TIMEOUT', '') -if RAG_EMBEDDING_TIMEOUT == "": +if RAG_EMBEDDING_TIMEOUT == '': RAG_EMBEDDING_TIMEOUT = None else: try: @@ -866,42 +769,34 @@ else: #################################### -SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "") -if SENTENCE_TRANSFORMERS_BACKEND == "": - SENTENCE_TRANSFORMERS_BACKEND = "torch" +SENTENCE_TRANSFORMERS_BACKEND = os.environ.get('SENTENCE_TRANSFORMERS_BACKEND', '') +if SENTENCE_TRANSFORMERS_BACKEND == '': + SENTENCE_TRANSFORMERS_BACKEND = 'torch' -SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get( - "SENTENCE_TRANSFORMERS_MODEL_KWARGS", "" -) -if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "": +SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get('SENTENCE_TRANSFORMERS_MODEL_KWARGS', '') +if SENTENCE_TRANSFORMERS_MODEL_KWARGS == '': SENTENCE_TRANSFORMERS_MODEL_KWARGS = None else: try: - SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads( - SENTENCE_TRANSFORMERS_MODEL_KWARGS - ) + SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(SENTENCE_TRANSFORMERS_MODEL_KWARGS) except Exception: SENTENCE_TRANSFORMERS_MODEL_KWARGS = None -SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", "" -) -if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "": - SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch" +SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get('SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND', '') +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == '': + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = 'torch' SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", "" + 'SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '' ) -if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "": +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == '': SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None else: try: - SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads( - SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS - ) + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS) except Exception: SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None @@ -909,23 +804,18 @@ else: # When enabled (default), scores are normalized to 0-1 range for proper # relevance threshold behavior with MS MARCO models. SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION = ( - os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION", "True" - ).lower() - == "true" + os.environ.get('SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION', 'True').lower() == 'true' ) #################################### # OFFLINE_MODE #################################### -ENABLE_VERSION_UPDATE_CHECK = ( - os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true" -) -OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" +ENABLE_VERSION_UPDATE_CHECK = os.environ.get('ENABLE_VERSION_UPDATE_CHECK', 'true').lower() == 'true' +OFFLINE_MODE = os.environ.get('OFFLINE_MODE', 'false').lower() == 'true' if OFFLINE_MODE: - os.environ["HF_HUB_OFFLINE"] = "1" + os.environ['HF_HUB_OFFLINE'] = '1' ENABLE_VERSION_UPDATE_CHECK = False #################################### @@ -933,113 +823,79 @@ if OFFLINE_MODE: #################################### -ENABLE_AUDIT_STDOUT = os.getenv("ENABLE_AUDIT_STDOUT", "False").lower() == "true" -ENABLE_AUDIT_LOGS_FILE = os.getenv("ENABLE_AUDIT_LOGS_FILE", "True").lower() == "true" +ENABLE_AUDIT_STDOUT = os.getenv('ENABLE_AUDIT_STDOUT', 'False').lower() == 'true' +ENABLE_AUDIT_LOGS_FILE = os.getenv('ENABLE_AUDIT_LOGS_FILE', 'True').lower() == 'true' # Where to store log file # Defaults to the DATA_DIR/audit.log. To set AUDIT_LOGS_FILE_PATH you need to # provide the whole path, like: /app/audit.log -AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log") +AUDIT_LOGS_FILE_PATH = os.getenv('AUDIT_LOGS_FILE_PATH', f'{DATA_DIR}/audit.log') # Maximum size of a file before rotating into a new log file -AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") +AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv('AUDIT_LOG_FILE_ROTATION_SIZE', '10MB') # Comma separated list of logger names to use for audit logging # Default is "uvicorn.access" which is the access log for Uvicorn # You can add more logger names to this list if you want to capture more logs -AUDIT_UVICORN_LOGGER_NAMES = os.getenv( - "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access" -).split(",") +AUDIT_UVICORN_LOGGER_NAMES = os.getenv('AUDIT_UVICORN_LOGGER_NAMES', 'uvicorn.access').split(',') # METADATA | REQUEST | REQUEST_RESPONSE -AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() +AUDIT_LOG_LEVEL = os.getenv('AUDIT_LOG_LEVEL', 'NONE').upper() try: - MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048) + MAX_BODY_LOG_SIZE = int(os.environ.get('MAX_BODY_LOG_SIZE') or 2048) except ValueError: MAX_BODY_LOG_SIZE = 2048 # Comma separated list for urls to exclude from audit -AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split( - "," -) +AUDIT_EXCLUDED_PATHS = os.getenv('AUDIT_EXCLUDED_PATHS', '/chats,/chat,/folders').split(',') AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] -AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] +AUDIT_EXCLUDED_PATHS = [path.lstrip('/') for path in AUDIT_EXCLUDED_PATHS] # Comma separated list of urls to include in audit (whitelist mode) # When set, only these paths are audited and AUDIT_EXCLUDED_PATHS is ignored -AUDIT_INCLUDED_PATHS = os.getenv("AUDIT_INCLUDED_PATHS", "").split(",") +AUDIT_INCLUDED_PATHS = os.getenv('AUDIT_INCLUDED_PATHS', '').split(',') AUDIT_INCLUDED_PATHS = [path.strip() for path in AUDIT_INCLUDED_PATHS] -AUDIT_INCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_INCLUDED_PATHS if path] +AUDIT_INCLUDED_PATHS = [path.lstrip('/') for path in AUDIT_INCLUDED_PATHS if path] #################################### # OPENTELEMETRY #################################### -ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true" -ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true" -ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true" -ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true" +ENABLE_OTEL = os.environ.get('ENABLE_OTEL', 'False').lower() == 'true' +ENABLE_OTEL_TRACES = os.environ.get('ENABLE_OTEL_TRACES', 'False').lower() == 'true' +ENABLE_OTEL_METRICS = os.environ.get('ENABLE_OTEL_METRICS', 'False').lower() == 'true' +ENABLE_OTEL_LOGS = os.environ.get('ENABLE_OTEL_LOGS', 'False').lower() == 'true' -OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" -) -OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_EXPORTER_OTLP_INSECURE = ( - os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true" -) +OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_EXPORTER_OTLP_ENDPOINT', 'http://localhost:4317') +OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_METRICS_EXPORTER_OTLP_ENDPOINT', OTEL_EXPORTER_OTLP_ENDPOINT) +OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_LOGS_EXPORTER_OTLP_ENDPOINT', OTEL_EXPORTER_OTLP_ENDPOINT) +OTEL_EXPORTER_OTLP_INSECURE = os.environ.get('OTEL_EXPORTER_OTLP_INSECURE', 'False').lower() == 'true' OTEL_METRICS_EXPORTER_OTLP_INSECURE = ( - os.environ.get( - "OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) - ).lower() - == "true" + os.environ.get('OTEL_METRICS_EXPORTER_OTLP_INSECURE', str(OTEL_EXPORTER_OTLP_INSECURE)).lower() == 'true' ) OTEL_LOGS_EXPORTER_OTLP_INSECURE = ( - os.environ.get( - "OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) - ).lower() - == "true" -) -OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui") -OTEL_RESOURCE_ATTRIBUTES = os.environ.get( - "OTEL_RESOURCE_ATTRIBUTES", "" -) # e.g. key1=val1,key2=val2 -OTEL_TRACES_SAMPLER = os.environ.get( - "OTEL_TRACES_SAMPLER", "parentbased_always_on" -).lower() -OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "") -OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "") -OTEL_METRICS_EXPORT_INTERVAL_MILLIS = int( - os.environ.get("OTEL_METRICS_EXPORT_INTERVAL_MILLIS", "10000") + os.environ.get('OTEL_LOGS_EXPORTER_OTLP_INSECURE', str(OTEL_EXPORTER_OTLP_INSECURE)).lower() == 'true' ) +OTEL_SERVICE_NAME = os.environ.get('OTEL_SERVICE_NAME', 'open-webui') +OTEL_RESOURCE_ATTRIBUTES = os.environ.get('OTEL_RESOURCE_ATTRIBUTES', '') # e.g. key1=val1,key2=val2 +OTEL_TRACES_SAMPLER = os.environ.get('OTEL_TRACES_SAMPLER', 'parentbased_always_on').lower() +OTEL_BASIC_AUTH_USERNAME = os.environ.get('OTEL_BASIC_AUTH_USERNAME', '') +OTEL_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_BASIC_AUTH_PASSWORD', '') +OTEL_METRICS_EXPORT_INTERVAL_MILLIS = int(os.environ.get('OTEL_METRICS_EXPORT_INTERVAL_MILLIS', '10000')) -OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get( - "OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME -) -OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get( - "OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD -) -OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get( - "OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME -) -OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get( - "OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD -) +OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get('OTEL_METRICS_BASIC_AUTH_USERNAME', OTEL_BASIC_AUTH_USERNAME) +OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_METRICS_BASIC_AUTH_PASSWORD', OTEL_BASIC_AUTH_PASSWORD) +OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get('OTEL_LOGS_BASIC_AUTH_USERNAME', OTEL_BASIC_AUTH_USERNAME) +OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_LOGS_BASIC_AUTH_PASSWORD', OTEL_BASIC_AUTH_PASSWORD) -OTEL_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_OTLP_SPAN_EXPORTER", "grpc" -).lower() # grpc or http +OTEL_OTLP_SPAN_EXPORTER = os.environ.get('OTEL_OTLP_SPAN_EXPORTER', 'grpc').lower() # grpc or http OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER + 'OTEL_METRICS_OTLP_SPAN_EXPORTER', OTEL_OTLP_SPAN_EXPORTER ).lower() # grpc or http OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER + 'OTEL_LOGS_OTLP_SPAN_EXPORTER', OTEL_OTLP_SPAN_EXPORTER ).lower() # grpc or http #################################### @@ -1047,19 +903,18 @@ OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get( #################################### ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS = ( - os.environ.get("ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS", "True").lower() - == "true" + os.environ.get('ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS', 'True').lower() == 'true' ) -PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split() -PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split() +PIP_OPTIONS = os.getenv('PIP_OPTIONS', '').split() +PIP_PACKAGE_INDEX_OPTIONS = os.getenv('PIP_PACKAGE_INDEX_OPTIONS', '').split() #################################### # PROGRESSIVE WEB APP OPTIONS #################################### -EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL") +EXTERNAL_PWA_MANIFEST_URL = os.environ.get('EXTERNAL_PWA_MANIFEST_URL') #################################### # GROUP DEFAULTS @@ -1067,9 +922,5 @@ EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL") # Controls the default "Who can share to this group" setting for new groups. # Env var values: "true" (anyone), "false" (no one), "members" (only group members). -_default_group_share = ( - os.environ.get("DEFAULT_GROUP_SHARE_PERMISSION", "members").strip().lower() -) -DEFAULT_GROUP_SHARE_PERMISSION = ( - "members" if _default_group_share == "members" else _default_group_share == "true" -) +_default_group_share = os.environ.get('DEFAULT_GROUP_SHARE_PERMISSION', 'members').strip().lower() +DEFAULT_GROUP_SHARE_PERMISSION = 'members' if _default_group_share == 'members' else _default_group_share == 'true' diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 364d2a889b..9bfe77c41e 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -57,17 +57,15 @@ log = logging.getLogger(__name__) def get_function_module_by_id(request: Request, pipe_id: str): function_module, _, _ = get_function_module_from_cache(request, pipe_id) - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'): Valves = function_module.Valves valves = Functions.get_function_valves_by_id(pipe_id) if valves: try: - function_module.valves = Valves( - **{k: v for k, v in valves.items() if v is not None} - ) + function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None}) except Exception as e: - log.exception(f"Error loading valves for function {pipe_id}: {e}") + log.exception(f'Error loading valves for function {pipe_id}: {e}') raise e else: function_module.valves = Valves() @@ -76,7 +74,7 @@ def get_function_module_by_id(request: Request, pipe_id: str): async def get_function_models(request): - pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipes = Functions.get_functions_by_type('pipe', active_only=True) pipe_models = [] for pipe in pipes: @@ -84,11 +82,11 @@ async def get_function_models(request): function_module = get_function_module_by_id(request, pipe.id) has_user_valves = False - if hasattr(function_module, "UserValves"): + if hasattr(function_module, 'UserValves'): has_user_valves = True # Check if function is a manifold - if hasattr(function_module, "pipes"): + if hasattr(function_module, 'pipes'): sub_pipes = [] # Handle pipes being a list, sync function, or async function @@ -104,32 +102,30 @@ async def get_function_models(request): log.exception(e) sub_pipes = [] - log.debug( - f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" - ) + log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}") for p in sub_pipes: sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] + sub_pipe_name = p['name'] - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + if hasattr(function_module, 'name'): + sub_pipe_name = f'{function_module.name}{sub_pipe_name}' - pipe_flag = {"type": pipe.type} + pipe_flag = {'type': pipe.type} pipe_models.append( { - "id": sub_pipe_id, - "name": sub_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - "has_user_valves": has_user_valves, + 'id': sub_pipe_id, + 'name': sub_pipe_name, + 'object': 'model', + 'created': pipe.created_at, + 'owned_by': 'openai', + 'pipe': pipe_flag, + 'has_user_valves': has_user_valves, } ) else: - pipe_flag = {"type": "pipe"} + pipe_flag = {'type': 'pipe'} log.debug( f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" @@ -137,13 +133,13 @@ async def get_function_models(request): pipe_models.append( { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - "has_user_valves": has_user_valves, + 'id': pipe.id, + 'name': pipe.name, + 'object': 'model', + 'created': pipe.created_at, + 'owned_by': 'openai', + 'pipe': pipe_flag, + 'has_user_valves': has_user_valves, } ) except Exception as e: @@ -153,9 +149,7 @@ async def get_function_models(request): return pipe_models -async def generate_function_chat_completion( - request, form_data, user, models: dict = {} -): +async def generate_function_chat_completion(request, form_data, user, models: dict = {}): async def execute_pipe(pipe, params): if inspect.iscoroutinefunction(pipe): return await pipe(**params) @@ -166,32 +160,32 @@ async def generate_function_chat_completion( if isinstance(res, str): return res if isinstance(res, Generator): - return "".join(map(str, res)) + return ''.join(map(str, res)) if isinstance(res, AsyncGenerator): - return "".join([str(stream) async for stream in res]) + return ''.join([str(stream) async for stream in res]) def process_line(form_data: dict, line): if isinstance(line, BaseModel): line = line.model_dump_json() - line = f"data: {line}" + line = f'data: {line}' if isinstance(line, dict): - line = f"data: {json.dumps(line)}" + line = f'data: {json.dumps(line)}' try: - line = line.decode("utf-8") + line = line.decode('utf-8') except Exception: pass - if line.startswith("data:"): - return f"{line}\n\n" + if line.startswith('data:'): + return f'{line}\n\n' else: - line = openai_chat_chunk_message_template(form_data["model"], line) - return f"data: {json.dumps(line)}\n\n" + line = openai_chat_chunk_message_template(form_data['model'], line) + return f'data: {json.dumps(line)}\n\n' def get_pipe_id(form_data: dict) -> str: - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, _ = pipe_id.split(".", 1) + pipe_id = form_data['model'] + if '.' in pipe_id: + pipe_id, _ = pipe_id.split('.', 1) return pipe_id def get_function_params(function_module, form_data, user, extra_params=None): @@ -202,27 +196,25 @@ async def generate_function_chat_completion( # Get the signature of the function sig = inspect.signature(function_module.pipe) - params = {"body": form_data} | { - k: v for k, v in extra_params.items() if k in sig.parameters - } + params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters} - if "__user__" in params and hasattr(function_module, "UserValves"): + if '__user__' in params and hasattr(function_module, 'UserValves'): user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) + params['__user__']['valves'] = function_module.UserValves(**user_valves) except Exception as e: log.exception(e) - params["__user__"]["valves"] = function_module.UserValves() + params['__user__']['valves'] = function_module.UserValves() return params - model_id = form_data.get("model") + model_id = form_data.get('model') model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", {}) + metadata = form_data.pop('metadata', {}) - files = metadata.get("files", []) - tool_ids = metadata.get("tool_ids", []) + files = metadata.get('files', []) + tool_ids = metadata.get('tool_ids', []) # Check if tool_ids is None if tool_ids is None: tool_ids = [] @@ -233,56 +225,56 @@ async def generate_function_chat_completion( __task_body__ = None if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')): __event_emitter__ = get_event_emitter(metadata) __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - __task_body__ = metadata.get("task_body", None) + __task__ = metadata.get('task', None) + __task_body__ = metadata.get('task_body', None) oauth_token = None try: - if request.cookies.get("oauth_session_id", None): + if request.cookies.get('oauth_session_id', None): oauth_token = await request.app.state.oauth_manager.get_oauth_token( user.id, - request.cookies.get("oauth_session_id", None), + request.cookies.get('oauth_session_id', None), ) except Exception as e: - log.error(f"Error getting OAuth token: {e}") + log.error(f'Error getting OAuth token: {e}') extra_params = { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__chat_id__": metadata.get("chat_id", None), - "__session_id__": metadata.get("session_id", None), - "__message_id__": metadata.get("message_id", None), - "__task__": __task__, - "__task_body__": __task_body__, - "__files__": files, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__oauth_token__": oauth_token, - "__request__": request, + '__event_emitter__': __event_emitter__, + '__event_call__': __event_call__, + '__chat_id__': metadata.get('chat_id', None), + '__session_id__': metadata.get('session_id', None), + '__message_id__': metadata.get('message_id', None), + '__task__': __task__, + '__task_body__': __task_body__, + '__files__': files, + '__user__': user.model_dump() if isinstance(user, UserModel) else {}, + '__metadata__': metadata, + '__oauth_token__': oauth_token, + '__request__': request, } - extra_params["__tools__"] = await get_tools( + extra_params['__tools__'] = await get_tools( request, tool_ids, user, { **extra_params, - "__model__": models.get(form_data["model"], None), - "__messages__": form_data["messages"], - "__files__": files, + '__model__': models.get(form_data['model'], None), + '__messages__': form_data['messages'], + '__files__': files, }, ) if model_info: if model_info.base_model_id: - form_data["model"] = model_info.base_model_id + form_data['model'] = model_info.base_model_id params = model_info.params.model_dump() if params: - system = params.pop("system", None) + system = params.pop('system', None) form_data = apply_model_params_to_body_openai(params, form_data) form_data = apply_system_prompt_to_body(system, form_data, metadata, user) @@ -292,7 +284,7 @@ async def generate_function_chat_completion( pipe = function_module.pipe params = get_function_params(function_module, form_data, user, extra_params) - if form_data.get("stream", False): + if form_data.get('stream', False): async def stream_content(): try: @@ -304,17 +296,17 @@ async def generate_function_chat_completion( yield data return if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" + yield f'data: {json.dumps(res)}\n\n' return except Exception as e: - log.error(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + log.error(f'Error: {e}') + yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n' return if isinstance(res, str): - message = openai_chat_chunk_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" + message = openai_chat_chunk_message_template(form_data['model'], res) + yield f'data: {json.dumps(message)}\n\n' if isinstance(res, Iterator): for line in res: @@ -325,21 +317,19 @@ async def generate_function_chat_completion( yield process_line(form_data, line) if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" + finish_message = openai_chat_chunk_message_template(form_data['model'], '') + finish_message['choices'][0]['finish_reason'] = 'stop' + yield f'data: {json.dumps(finish_message)}\n\n' + yield 'data: [DONE]' - return StreamingResponse(stream_content(), media_type="text/event-stream") + return StreamingResponse(stream_content(), media_type='text/event-stream') else: try: res = await execute_pipe(pipe, params) except Exception as e: - log.error(f"Error: {e}") - return {"error": {"detail": str(e)}} + log.error(f'Error: {e}') + return {'error': {'detail': str(e)}} if isinstance(res, StreamingResponse) or isinstance(res, dict): return res @@ -347,4 +337,4 @@ async def generate_function_chat_completion( return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data['model'], message) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index afc2e76621..b0545255a6 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL): # db = None try: # Replace the postgresql:// with postgres:// to handle the peewee migration - db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://")) - migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations" + db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://')) + migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations' router = Router(db, logger=log, migrate_dir=migrate_dir) router.run() db.close() except Exception as e: - log.error(f"Failed to initialize the database connection: {e}") - log.warning( - "Hint: If your database password contains special characters, you may need to URL-encode it." - ) + log.error(f'Failed to initialize the database connection: {e}') + log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.') raise finally: # Properly closing the database connection @@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL): db.close() # Assert if db connection has been closed - assert db.is_closed(), "Database connection is still open." + assert db.is_closed(), 'Database connection is still open.' if ENABLE_DB_MIGRATIONS: @@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS: SQLALCHEMY_DATABASE_URL = DATABASE_URL # Handle SQLCipher URLs -if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): - database_password = os.environ.get("DATABASE_PASSWORD") - if not database_password or database_password.strip() == "": - raise ValueError( - "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" - ) +if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'): + database_password = os.environ.get('DATABASE_PASSWORD') + if not database_password or database_password.strip() == '': + raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs') # Extract database path from SQLCipher URL - db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "") + db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '') # Create a custom creator function that uses sqlcipher3 def create_sqlcipher_connection(): @@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): # or QueuePool if DATABASE_POOL_SIZE is explicitly configured. if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0: engine = create_engine( - "sqlite://", + 'sqlite://', creator=create_sqlcipher_connection, pool_size=DATABASE_POOL_SIZE, max_overflow=DATABASE_POOL_MAX_OVERFLOW, @@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): ) else: engine = create_engine( - "sqlite://", + 'sqlite://', creator=create_sqlcipher_connection, poolclass=NullPool, echo=False, ) - log.info("Connected to encrypted SQLite database using SQLCipher") + log.info('Connected to encrypted SQLite database using SQLCipher') -elif "sqlite" in SQLALCHEMY_DATABASE_URL: - engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} - ) +elif 'sqlite' in SQLALCHEMY_DATABASE_URL: + engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False}) def on_connect(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() if DATABASE_ENABLE_SQLITE_WAL: - cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute('PRAGMA journal_mode=WAL') else: - cursor.execute("PRAGMA journal_mode=DELETE") + cursor.execute('PRAGMA journal_mode=DELETE') cursor.close() - event.listen(engine, "connect", on_connect) + event.listen(engine, 'connect', on_connect) else: if isinstance(DATABASE_POOL_SIZE, int): if DATABASE_POOL_SIZE > 0: @@ -156,16 +150,12 @@ else: poolclass=QueuePool, ) else: - engine = create_engine( - SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool - ) + engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) -SessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine, expire_on_commit=False -) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) metadata_obj = MetaData(schema=DATABASE_SCHEMA) Base = declarative_base(metadata=metadata_obj) ScopedSession = scoped_session(SessionLocal) diff --git a/backend/open_webui/internal/migrations/001_initial_schema.py b/backend/open_webui/internal/migrations/001_initial_schema.py index 0df2249b21..4268201ae7 100644 --- a/backend/open_webui/internal/migrations/001_initial_schema.py +++ b/backend/open_webui/internal/migrations/001_initial_schema.py @@ -56,7 +56,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): active = pw.BooleanField() class Meta: - table_name = "auth" + table_name = 'auth' @migrator.create_model class Chat(pw.Model): @@ -67,7 +67,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "chat" + table_name = 'chat' @migrator.create_model class ChatIdTag(pw.Model): @@ -78,7 +78,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "chatidtag" + table_name = 'chatidtag' @migrator.create_model class Document(pw.Model): @@ -92,7 +92,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "document" + table_name = 'document' @migrator.create_model class Modelfile(pw.Model): @@ -103,7 +103,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "modelfile" + table_name = 'modelfile' @migrator.create_model class Prompt(pw.Model): @@ -115,7 +115,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "prompt" + table_name = 'prompt' @migrator.create_model class Tag(pw.Model): @@ -125,7 +125,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): data = pw.TextField(null=True) class Meta: - table_name = "tag" + table_name = 'tag' @migrator.create_model class User(pw.Model): @@ -137,7 +137,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "user" + table_name = 'user' def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): @@ -149,7 +149,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): active = pw.BooleanField() class Meta: - table_name = "auth" + table_name = 'auth' @migrator.create_model class Chat(pw.Model): @@ -160,7 +160,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "chat" + table_name = 'chat' @migrator.create_model class ChatIdTag(pw.Model): @@ -171,7 +171,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "chatidtag" + table_name = 'chatidtag' @migrator.create_model class Document(pw.Model): @@ -185,7 +185,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "document" + table_name = 'document' @migrator.create_model class Modelfile(pw.Model): @@ -196,7 +196,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "modelfile" + table_name = 'modelfile' @migrator.create_model class Prompt(pw.Model): @@ -208,7 +208,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "prompt" + table_name = 'prompt' @migrator.create_model class Tag(pw.Model): @@ -218,7 +218,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): data = pw.TextField(null=True) class Meta: - table_name = "tag" + table_name = 'tag' @migrator.create_model class User(pw.Model): @@ -230,24 +230,24 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): timestamp = pw.BigIntegerField() class Meta: - table_name = "user" + table_name = 'user' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("user") + migrator.remove_model('user') - migrator.remove_model("tag") + migrator.remove_model('tag') - migrator.remove_model("prompt") + migrator.remove_model('prompt') - migrator.remove_model("modelfile") + migrator.remove_model('modelfile') - migrator.remove_model("document") + migrator.remove_model('document') - migrator.remove_model("chatidtag") + migrator.remove_model('chatidtag') - migrator.remove_model("chat") + migrator.remove_model('chat') - migrator.remove_model("auth") + migrator.remove_model('auth') diff --git a/backend/open_webui/internal/migrations/002_add_local_sharing.py b/backend/open_webui/internal/migrations/002_add_local_sharing.py index a01862d103..e3e557602b 100644 --- a/backend/open_webui/internal/migrations/002_add_local_sharing.py +++ b/backend/open_webui/internal/migrations/002_add_local_sharing.py @@ -36,12 +36,10 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "chat", share_id=pw.CharField(max_length=255, null=True, unique=True) - ) + migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("chat", "share_id") + migrator.remove_fields('chat', 'share_id') diff --git a/backend/open_webui/internal/migrations/003_add_auth_api_key.py b/backend/open_webui/internal/migrations/003_add_auth_api_key.py index 23cba26383..acb63fc728 100644 --- a/backend/open_webui/internal/migrations/003_add_auth_api_key.py +++ b/backend/open_webui/internal/migrations/003_add_auth_api_key.py @@ -36,12 +36,10 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", api_key=pw.CharField(max_length=255, null=True, unique=True) - ) + migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "api_key") + migrator.remove_fields('user', 'api_key') diff --git a/backend/open_webui/internal/migrations/004_add_archived.py b/backend/open_webui/internal/migrations/004_add_archived.py index 11108a3e0b..abed1727b9 100644 --- a/backend/open_webui/internal/migrations/004_add_archived.py +++ b/backend/open_webui/internal/migrations/004_add_archived.py @@ -36,10 +36,10 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields("chat", archived=pw.BooleanField(default=False)) + migrator.add_fields('chat', archived=pw.BooleanField(default=False)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("chat", "archived") + migrator.remove_fields('chat', 'archived') diff --git a/backend/open_webui/internal/migrations/005_add_updated_at.py b/backend/open_webui/internal/migrations/005_add_updated_at.py index f7fc69a5db..bff311e2d4 100644 --- a/backend/open_webui/internal/migrations/005_add_updated_at.py +++ b/backend/open_webui/internal/migrations/005_add_updated_at.py @@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): # Adding fields created_at and updated_at to the 'chat' table migrator.add_fields( - "chat", + 'chat', created_at=pw.DateTimeField(null=True), # Allow null for transition updated_at=pw.DateTimeField(null=True), # Allow null for transition ) # Populate the new fields from an existing 'timestamp' field - migrator.sql( - "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" - ) + migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL') # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("chat", "timestamp") + migrator.remove_fields('chat', 'timestamp') # Update the fields to be not null now that they are populated migrator.change_fields( - "chat", + 'chat', created_at=pw.DateTimeField(null=False), updated_at=pw.DateTimeField(null=False), ) @@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): # Adding fields created_at and updated_at to the 'chat' table migrator.add_fields( - "chat", + 'chat', created_at=pw.BigIntegerField(null=True), # Allow null for transition updated_at=pw.BigIntegerField(null=True), # Allow null for transition ) # Populate the new fields from an existing 'timestamp' field - migrator.sql( - "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" - ) + migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL') # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("chat", "timestamp") + migrator.remove_fields('chat', 'timestamp') # Update the fields to be not null now that they are populated migrator.change_fields( - "chat", + 'chat', created_at=pw.BigIntegerField(null=False), updated_at=pw.BigIntegerField(null=False), ) @@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) + migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True)) # Copy the earliest created_at date back into the new timestamp field # This assumes created_at was originally a copy of timestamp - migrator.sql("UPDATE chat SET timestamp = created_at") + migrator.sql('UPDATE chat SET timestamp = created_at') # Remove the created_at and updated_at fields - migrator.remove_fields("chat", "created_at", "updated_at") + migrator.remove_fields('chat', 'created_at', 'updated_at') # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) + migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False)) def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False): # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True)) + migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True)) # Copy the earliest created_at date back into the new timestamp field # This assumes created_at was originally a copy of timestamp - migrator.sql("UPDATE chat SET timestamp = created_at") + migrator.sql('UPDATE chat SET timestamp = created_at') # Remove the created_at and updated_at fields - migrator.remove_fields("chat", "created_at", "updated_at") + migrator.remove_fields('chat', 'created_at', 'updated_at') # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False)) + migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py index abe7016c57..86f90eb880 100644 --- a/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py +++ b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py @@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): # Alter the tables with timestamps migrator.change_fields( - "chatidtag", + 'chatidtag', timestamp=pw.BigIntegerField(), ) migrator.change_fields( - "document", + 'document', timestamp=pw.BigIntegerField(), ) migrator.change_fields( - "modelfile", + 'modelfile', timestamp=pw.BigIntegerField(), ) migrator.change_fields( - "prompt", + 'prompt', timestamp=pw.BigIntegerField(), ) migrator.change_fields( - "user", + 'user', timestamp=pw.BigIntegerField(), ) # Alter the tables with varchar to text where necessary migrator.change_fields( - "auth", + 'auth', password=pw.TextField(), ) migrator.change_fields( - "chat", + 'chat', title=pw.TextField(), ) migrator.change_fields( - "document", + 'document', title=pw.TextField(), filename=pw.TextField(), ) migrator.change_fields( - "prompt", + 'prompt', title=pw.TextField(), ) migrator.change_fields( - "user", + 'user', profile_image_url=pw.TextField(), ) @@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False): if isinstance(database, pw.SqliteDatabase): # Alter the tables with timestamps migrator.change_fields( - "chatidtag", + 'chatidtag', timestamp=pw.DateField(), ) migrator.change_fields( - "document", + 'document', timestamp=pw.DateField(), ) migrator.change_fields( - "modelfile", + 'modelfile', timestamp=pw.DateField(), ) migrator.change_fields( - "prompt", + 'prompt', timestamp=pw.DateField(), ) migrator.change_fields( - "user", + 'user', timestamp=pw.DateField(), ) migrator.change_fields( - "auth", + 'auth', password=pw.CharField(max_length=255), ) migrator.change_fields( - "chat", + 'chat', title=pw.CharField(), ) migrator.change_fields( - "document", + 'document', title=pw.CharField(), filename=pw.CharField(), ) migrator.change_fields( - "prompt", + 'prompt', title=pw.CharField(), ) migrator.change_fields( - "user", + 'user', profile_image_url=pw.CharField(), ) diff --git a/backend/open_webui/internal/migrations/007_add_user_last_active_at.py b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py index 3f89a5f59f..19a26c3515 100644 --- a/backend/open_webui/internal/migrations/007_add_user_last_active_at.py +++ b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py @@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): # Adding fields created_at and updated_at to the 'user' table migrator.add_fields( - "user", + 'user', created_at=pw.BigIntegerField(null=True), # Allow null for transition updated_at=pw.BigIntegerField(null=True), # Allow null for transition last_active_at=pw.BigIntegerField(null=True), # Allow null for transition @@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): ) # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("user", "timestamp") + migrator.remove_fields('user', 'timestamp') # Update the fields to be not null now that they are populated migrator.change_fields( - "user", + 'user', created_at=pw.BigIntegerField(null=False), updated_at=pw.BigIntegerField(null=False), last_active_at=pw.BigIntegerField(null=False), @@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True)) + migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True)) # Copy the earliest created_at date back into the new timestamp field # This assumes created_at was originally a copy of timestamp migrator.sql('UPDATE "user" SET timestamp = created_at') # Remove the created_at and updated_at fields - migrator.remove_fields("user", "created_at", "updated_at", "last_active_at") + migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at') # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False)) + migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/open_webui/internal/migrations/008_add_memory.py b/backend/open_webui/internal/migrations/008_add_memory.py index 96be907eba..f3af64fe95 100644 --- a/backend/open_webui/internal/migrations/008_add_memory.py +++ b/backend/open_webui/internal/migrations/008_add_memory.py @@ -43,10 +43,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): created_at = pw.BigIntegerField(null=False) class Meta: - table_name = "memory" + table_name = 'memory' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("memory") + migrator.remove_model('memory') diff --git a/backend/open_webui/internal/migrations/009_add_models.py b/backend/open_webui/internal/migrations/009_add_models.py index 0a8d73bd3b..45f4a3d163 100644 --- a/backend/open_webui/internal/migrations/009_add_models.py +++ b/backend/open_webui/internal/migrations/009_add_models.py @@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): updated_at = pw.BigIntegerField(null=False) class Meta: - table_name = "model" + table_name = 'model' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + migrator.remove_model('model') diff --git a/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py index 322ddd44ec..e523d6a098 100644 --- a/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py +++ b/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py @@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): # Fetch data from 'modelfile' table and insert into 'model' table migrate_modelfile_to_model(migrator, database) # Drop the 'modelfile' table - migrator.remove_model("modelfile") + migrator.remove_model('modelfile') def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): - ModelFile = migrator.orm["modelfile"] - Model = migrator.orm["model"] + ModelFile = migrator.orm['modelfile'] + Model = migrator.orm['model'] modelfiles = ModelFile.select() @@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): modelfile.modelfile = json.loads(modelfile.modelfile) meta = json.dumps( { - "description": modelfile.modelfile.get("desc"), - "profile_image_url": modelfile.modelfile.get("imageUrl"), - "ollama": {"modelfile": modelfile.modelfile.get("content")}, - "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), - "categories": modelfile.modelfile.get("categories"), - "user": {**modelfile.modelfile.get("user", {}), "community": True}, + 'description': modelfile.modelfile.get('desc'), + 'profile_image_url': modelfile.modelfile.get('imageUrl'), + 'ollama': {'modelfile': modelfile.modelfile.get('content')}, + 'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'), + 'categories': modelfile.modelfile.get('categories'), + 'user': {**modelfile.modelfile.get('user', {}), 'community': True}, } ) - info = parse_ollama_modelfile(modelfile.modelfile.get("content")) + info = parse_ollama_modelfile(modelfile.modelfile.get('content')) # Insert the processed data into the 'model' table Model.create( - id=f"ollama-{modelfile.tag_name}", + id=f'ollama-{modelfile.tag_name}', user_id=modelfile.user_id, - base_model_id=info.get("base_model_id"), - name=modelfile.modelfile.get("title"), + base_model_id=info.get('base_model_id'), + name=modelfile.modelfile.get('title'), meta=meta, - params=json.dumps(info.get("params", {})), + params=json.dumps(info.get('params', {})), created_at=modelfile.timestamp, updated_at=modelfile.timestamp, ) @@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False): recreate_modelfile_table(migrator, database) move_data_back_to_modelfile(migrator, database) - migrator.remove_model("model") + migrator.remove_model('model') def recreate_modelfile_table(migrator: Migrator, database: pw.Database): @@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database): def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): - Model = migrator.orm["model"] - Modelfile = migrator.orm["modelfile"] + Model = migrator.orm['model'] + Modelfile = migrator.orm['modelfile'] models = Model.select() @@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): meta = json.loads(model.meta) modelfile_data = { - "title": model.name, - "desc": meta.get("description"), - "imageUrl": meta.get("profile_image_url"), - "content": meta.get("ollama", {}).get("modelfile"), - "suggestionPrompts": meta.get("suggestion_prompts"), - "categories": meta.get("categories"), - "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, + 'title': model.name, + 'desc': meta.get('description'), + 'imageUrl': meta.get('profile_image_url'), + 'content': meta.get('ollama', {}).get('modelfile'), + 'suggestionPrompts': meta.get('suggestion_prompts'), + 'categories': meta.get('categories'), + 'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'}, } # Insert the processed data back into the 'modelfile' table diff --git a/backend/open_webui/internal/migrations/011_add_user_settings.py b/backend/open_webui/internal/migrations/011_add_user_settings.py index c3b9ab6edc..73d27392f7 100644 --- a/backend/open_webui/internal/migrations/011_add_user_settings.py +++ b/backend/open_webui/internal/migrations/011_add_user_settings.py @@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" # Adding fields settings to the 'user' table - migrator.add_fields("user", settings=pw.TextField(null=True)) + migrator.add_fields('user', settings=pw.TextField(null=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" # Remove the settings field - migrator.remove_fields("user", "settings") + migrator.remove_fields('user', 'settings') diff --git a/backend/open_webui/internal/migrations/012_add_tools.py b/backend/open_webui/internal/migrations/012_add_tools.py index ac3cd8bfec..a488678c3c 100644 --- a/backend/open_webui/internal/migrations/012_add_tools.py +++ b/backend/open_webui/internal/migrations/012_add_tools.py @@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): updated_at = pw.BigIntegerField(null=False) class Meta: - table_name = "tool" + table_name = 'tool' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("tool") + migrator.remove_model('tool') diff --git a/backend/open_webui/internal/migrations/013_add_user_info.py b/backend/open_webui/internal/migrations/013_add_user_info.py index 6fafa951f0..db77cfff3a 100644 --- a/backend/open_webui/internal/migrations/013_add_user_info.py +++ b/backend/open_webui/internal/migrations/013_add_user_info.py @@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" # Adding fields info to the 'user' table - migrator.add_fields("user", info=pw.TextField(null=True)) + migrator.add_fields('user', info=pw.TextField(null=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" # Remove the settings field - migrator.remove_fields("user", "info") + migrator.remove_fields('user', 'info') diff --git a/backend/open_webui/internal/migrations/014_add_files.py b/backend/open_webui/internal/migrations/014_add_files.py index 655b00d238..9c01ac08c3 100644 --- a/backend/open_webui/internal/migrations/014_add_files.py +++ b/backend/open_webui/internal/migrations/014_add_files.py @@ -45,10 +45,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): created_at = pw.BigIntegerField(null=False) class Meta: - table_name = "file" + table_name = 'file' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("file") + migrator.remove_model('file') diff --git a/backend/open_webui/internal/migrations/015_add_functions.py b/backend/open_webui/internal/migrations/015_add_functions.py index 84d2843839..488e546ab1 100644 --- a/backend/open_webui/internal/migrations/015_add_functions.py +++ b/backend/open_webui/internal/migrations/015_add_functions.py @@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): updated_at = pw.BigIntegerField(null=False) class Meta: - table_name = "function" + table_name = 'function' def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("function") + migrator.remove_model('function') diff --git a/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py index fadf964e46..57a2dfbd5b 100644 --- a/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py +++ b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py @@ -36,14 +36,14 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields("tool", valves=pw.TextField(null=True)) - migrator.add_fields("function", valves=pw.TextField(null=True)) - migrator.add_fields("function", is_active=pw.BooleanField(default=False)) + migrator.add_fields('tool', valves=pw.TextField(null=True)) + migrator.add_fields('function', valves=pw.TextField(null=True)) + migrator.add_fields('function', is_active=pw.BooleanField(default=False)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("tool", "valves") - migrator.remove_fields("function", "valves") - migrator.remove_fields("function", "is_active") + migrator.remove_fields('tool', 'valves') + migrator.remove_fields('function', 'valves') + migrator.remove_fields('function', 'is_active') diff --git a/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py index 67a36b4889..f998c742d1 100644 --- a/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py +++ b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py @@ -33,7 +33,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" migrator.add_fields( - "user", + 'user', oauth_sub=pw.TextField(null=True, unique=True), ) @@ -41,4 +41,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + migrator.remove_fields('user', 'oauth_sub') diff --git a/backend/open_webui/internal/migrations/018_add_function_is_global.py b/backend/open_webui/internal/migrations/018_add_function_is_global.py index 1e932ed710..7f7cd4f725 100644 --- a/backend/open_webui/internal/migrations/018_add_function_is_global.py +++ b/backend/open_webui/internal/migrations/018_add_function_is_global.py @@ -37,7 +37,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" migrator.add_fields( - "function", + 'function', is_global=pw.BooleanField(default=False), ) @@ -45,4 +45,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("function", "is_global") + migrator.remove_fields('function', 'is_global') diff --git a/backend/open_webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py index 80b1aab8ff..3d54d02e3a 100644 --- a/backend/open_webui/internal/wrappers.py +++ b/backend/open_webui/internal/wrappers.py @@ -10,13 +10,13 @@ from playhouse.shortcuts import ReconnectMixin log = logging.getLogger(__name__) -db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} -db_state = ContextVar("db_state", default=db_state_default.copy()) +db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None} +db_state = ContextVar('db_state', default=db_state_default.copy()) class PeeweeConnectionState(object): def __init__(self, **kwargs): - super().__setattr__("_state", db_state) + super().__setattr__('_state', db_state) super().__init__(**kwargs) def __setattr__(self, name, value): @@ -30,10 +30,10 @@ class PeeweeConnectionState(object): class CustomReconnectMixin(ReconnectMixin): reconnect_errors = ( # psycopg2 - (OperationalError, "termin"), - (InterfaceError, "closed"), + (OperationalError, 'termin'), + (InterfaceError, 'closed'), # peewee - (PeeWeeInterfaceError, "closed"), + (PeeWeeInterfaceError, 'closed'), ) @@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): def register_connection(db_url): # Check if using SQLCipher protocol - if db_url.startswith("sqlite+sqlcipher://"): - database_password = os.environ.get("DATABASE_PASSWORD") - if not database_password or database_password.strip() == "": - raise ValueError( - "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" - ) + if db_url.startswith('sqlite+sqlcipher://'): + database_password = os.environ.get('DATABASE_PASSWORD') + if not database_password or database_password.strip() == '': + raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs') from playhouse.sqlcipher_ext import SqlCipherDatabase # Parse the database path from SQLCipher URL # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite - db_path = db_url.replace("sqlite+sqlcipher://", "") + db_path = db_url.replace('sqlite+sqlcipher://', '') # Use Peewee's native SqlCipherDatabase with encryption db = SqlCipherDatabase(db_path, passphrase=database_password) db.autoconnect = True db.reuse_if_open = True - log.info("Connected to encrypted SQLite database using SQLCipher") + log.info('Connected to encrypted SQLite database using SQLCipher') else: # Standard database connection (existing logic) @@ -68,7 +66,7 @@ def register_connection(db_url): # Enable autoconnect for SQLite databases, managed by Peewee db.autoconnect = True db.reuse_if_open = True - log.info("Connected to PostgreSQL database") + log.info('Connected to PostgreSQL database') # Get the connection details connection = parse(db_url, unquote_user=True, unquote_password=True) @@ -80,7 +78,7 @@ def register_connection(db_url): # Enable autoconnect for SQLite databases, managed by Peewee db.autoconnect = True db.reuse_if_open = True - log.info("Connected to SQLite database") + log.info('Connected to SQLite database') else: - raise ValueError("Unsupported database connection") + raise ValueError('Unsupported database connection') return db diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 93b3b5bece..7841c84447 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -565,7 +565,7 @@ from open_webui.utils.redis import get_sentinels_from_env from open_webui.constants import ERROR_MESSAGES if SAFE_MODE: - print("SAFE MODE ENABLED") + print('SAFE MODE ENABLED') Functions.deactivate_all_functions() logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -578,16 +578,16 @@ class SPAStaticFiles(StaticFiles): return await super().get_response(path, scope) except (HTTPException, StarletteHTTPException) as ex: if ex.status_code == 404: - if path.endswith(".js"): + if path.endswith('.js'): # Return 404 for javascript files raise ex else: - return await super().get_response("index.html", scope) + return await super().get_response('index.html', scope) else: raise ex -if LOG_FORMAT != "json": +if LOG_FORMAT != 'json': print(rf""" ██████╗ ██████╗ ███████╗███╗ ██╗ ██╗ ██╗███████╗██████╗ ██╗ ██╗██╗ ██╔═══██╗██╔══██╗██╔════╝████╗ ██║ ██║ ██║██╔════╝██╔══██╗██║ ██║██║ @@ -598,7 +598,7 @@ if LOG_FORMAT != "json": v{VERSION} - building the best AI user interface. -{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} +{f'Commit: {WEBUI_BUILD_HASH}' if WEBUI_BUILD_HASH != 'dev-build' else ''} https://github.com/open-webui/open-webui """) @@ -626,22 +626,18 @@ async def lifespan(app: FastAPI): # This should be blocking (sync) so functions are not deactivated on first /get_models calls # when the first user lands on the / route. - log.info("Installing external dependencies of functions and tools...") + log.info('Installing external dependencies of functions and tools...') install_tool_and_function_dependencies() app.state.redis = get_redis_connection( redis_url=REDIS_URL, - redis_sentinels=get_sentinels_from_env( - REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT - ), + redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), redis_cluster=REDIS_CLUSTER, async_mode=True, ) if app.state.redis is not None: - app.state.redis_task_command_listener = asyncio.create_task( - redis_task_command_listener(app) - ) + app.state.redis_task_command_listener = asyncio.create_task(redis_task_command_listener(app)) if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0: limiter = anyio.to_thread.current_default_thread_limiter() @@ -656,66 +652,64 @@ async def lifespan(app: FastAPI): Request( # Creating a mock request object to pass to get_all_models { - "type": "http", - "asgi.version": "3.0", - "asgi.spec_version": "2.0", - "method": "GET", - "path": "/internal", - "query_string": b"", - "headers": Headers({}).raw, - "client": ("127.0.0.1", 12345), - "server": ("127.0.0.1", 80), - "scheme": "http", - "app": app, + 'type': 'http', + 'asgi.version': '3.0', + 'asgi.spec_version': '2.0', + 'method': 'GET', + 'path': '/internal', + 'query_string': b'', + 'headers': Headers({}).raw, + 'client': ('127.0.0.1', 12345), + 'server': ('127.0.0.1', 80), + 'scheme': 'http', + 'app': app, } ), None, ) except Exception as e: - log.warning(f"Failed to pre-fetch models at startup: {e}") + log.warning(f'Failed to pre-fetch models at startup: {e}') # Pre-fetch tool server specs so the first request doesn't pay the latency cost if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0: - log.info("Initializing tool servers...") + log.info('Initializing tool servers...') try: mock_request = Request( { - "type": "http", - "asgi.version": "3.0", - "asgi.spec_version": "2.0", - "method": "GET", - "path": "/internal", - "query_string": b"", - "headers": Headers({}).raw, - "client": ("127.0.0.1", 12345), - "server": ("127.0.0.1", 80), - "scheme": "http", - "app": app, + 'type': 'http', + 'asgi.version': '3.0', + 'asgi.spec_version': '2.0', + 'method': 'GET', + 'path': '/internal', + 'query_string': b'', + 'headers': Headers({}).raw, + 'client': ('127.0.0.1', 12345), + 'server': ('127.0.0.1', 80), + 'scheme': 'http', + 'app': app, } ) await set_tool_servers(mock_request) - log.info(f"Initialized {len(app.state.TOOL_SERVERS)} tool server(s)") + log.info(f'Initialized {len(app.state.TOOL_SERVERS)} tool server(s)') await set_terminal_servers(mock_request) - log.info( - f"Initialized {len(app.state.TERMINAL_SERVERS)} terminal server(s)" - ) + log.info(f'Initialized {len(app.state.TERMINAL_SERVERS)} terminal server(s)') except Exception as e: - log.warning(f"Failed to initialize tool/terminal servers at startup: {e}") + log.warning(f'Failed to initialize tool/terminal servers at startup: {e}') # Mark application as ready to accept traffic from a startup perspective. app.state.startup_complete = True yield - if hasattr(app.state, "redis_task_command_listener"): + if hasattr(app.state, 'redis_task_command_listener'): app.state.redis_task_command_listener.cancel() app = FastAPI( - title="Open WebUI", - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, + title='Open WebUI', + docs_url='/docs' if ENV == 'dev' else None, + openapi_url='/openapi.json' if ENV == 'dev' else None, redoc_url=None, lifespan=lifespan, ) @@ -837,9 +831,7 @@ app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM app.state.config.ENABLE_API_KEYS = ENABLE_API_KEYS -app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = ( - ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS -) +app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS app.state.config.API_KEYS_ALLOWED_ENDPOINTS = API_KEYS_ALLOWED_ENDPOINTS app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN @@ -885,15 +877,15 @@ app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS from open_webui.utils.access_control import migrate_access_control connections = app.state.config.TOOL_SERVER_CONNECTIONS -if any("access_control" in c.get("config", {}) for c in connections): +if any('access_control' in c.get('config', {}) for c in connections): for connection in connections: - migrate_access_control(connection.get("config", {})) + migrate_access_control(connection.get('config', {})) app.state.config.TOOL_SERVER_CONNECTIONS = connections arena_models = app.state.config.EVALUATION_ARENA_MODELS -if any("access_control" in m.get("meta", {}) for m in arena_models): +if any('access_control' in m.get('meta', {}) for m in arena_models): for model in arena_models: - migrate_access_control(model.get("meta", {})) + migrate_access_control(model.get('meta', {})) app.state.config.EVALUATION_ARENA_MODELS = arena_models app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM @@ -962,9 +954,7 @@ app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = ( - ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS -) +app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE @@ -975,9 +965,7 @@ app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR -app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = ( - DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION -) +app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION app.state.config.DATALAB_MARKER_FORMAT_LINES = DATALAB_MARKER_FORMAT_LINES app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT @@ -999,9 +987,7 @@ app.state.config.MINERU_API_TIMEOUT = MINERU_API_TIMEOUT app.state.config.MINERU_PARAMS = MINERU_PARAMS app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ( - ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER -) +app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME @@ -1053,9 +1039,7 @@ app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = WEB_LOADER_CONCURRENT_REQUESTS app.state.config.WEB_LOADER_TIMEOUT = WEB_LOADER_TIMEOUT app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV -app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( - BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL -) +app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION @@ -1120,13 +1104,8 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None try: - app.state.ef = get_ef( - app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL - ) - if ( - app.state.config.ENABLE_RAG_HYBRID_SEARCH - and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ): + app.state.ef = get_ef(app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL) + if app.state.config.ENABLE_RAG_HYBRID_SEARCH and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: app.state.rf = get_rf( app.state.config.RAG_RERANKING_ENGINE, app.state.config.RAG_RERANKING_MODEL, @@ -1137,7 +1116,7 @@ try: else: app.state.rf = None except Exception as e: - log.error(f"Error updating models: {e}") + log.error(f'Error updating models: {e}') pass @@ -1147,26 +1126,26 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( embedding_function=app.state.ef, url=( app.state.config.RAG_OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( app.state.config.RAG_OLLAMA_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else app.state.config.RAG_AZURE_OPENAI_BASE_URL ) ), key=( app.state.config.RAG_OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( app.state.config.RAG_OLLAMA_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else app.state.config.RAG_AZURE_OPENAI_API_KEY ) ), embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE, azure_api_version=( app.state.config.RAG_AZURE_OPENAI_API_VERSION - if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + if app.state.config.RAG_EMBEDDING_ENGINE == 'azure_openai' else None ), enable_async=app.state.config.ENABLE_ASYNC_EMBEDDING, @@ -1190,9 +1169,7 @@ app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN -app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( - CODE_EXECUTION_JUPYTER_AUTH_PASSWORD -) +app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = CODE_EXECUTION_JUPYTER_AUTH_PASSWORD app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER @@ -1201,12 +1178,8 @@ app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMP app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH -app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( - CODE_INTERPRETER_JUPYTER_AUTH_TOKEN -) -app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( - CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD -) +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = CODE_INTERPRETER_JUPYTER_AUTH_TOKEN +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT ######################################## @@ -1282,9 +1255,7 @@ app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS app.state.config.AUDIO_STT_MISTRAL_API_KEY = AUDIO_STT_MISTRAL_API_KEY app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = AUDIO_STT_MISTRAL_API_BASE_URL -app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = ( - AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS -) +app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE @@ -1330,23 +1301,13 @@ app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( - IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE -) -app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( - FOLLOW_UP_GENERATION_PROMPT_TEMPLATE -) +app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE +app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = FOLLOW_UP_GENERATION_PROMPT_TEMPLATE -app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE -) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE -app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE -) -app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH -) +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE @@ -1366,36 +1327,36 @@ if ENABLE_COMPRESSION_MIDDLEWARE: class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Check if the request is a GET request - if request.method == "GET": + if request.method == 'GET': path = request.url.path query_params = dict(parse_qs(urlparse(str(request.url)).query)) redirect_params = {} # Check for the specific watch path and the presence of 'v' parameter - if path.endswith("/watch") and "v" in query_params: + if path.endswith('/watch') and 'v' in query_params: # Extract the first 'v' parameter - youtube_video_id = query_params["v"][0] - redirect_params["youtube"] = youtube_video_id + youtube_video_id = query_params['v'][0] + redirect_params['youtube'] = youtube_video_id - if "shared" in query_params and len(query_params["shared"]) > 0: + if 'shared' in query_params and len(query_params['shared']) > 0: # PWA share_target support - text = query_params["shared"][0] + text = query_params['shared'][0] if text: - urls = re.match(r"https://\S+", text) + urls = re.match(r'https://\S+', text) if urls: from open_webui.retrieval.loaders.youtube import _parse_video_id if youtube_video_id := _parse_video_id(urls[0]): - redirect_params["youtube"] = youtube_video_id + redirect_params['youtube'] = youtube_video_id else: - redirect_params["load-url"] = urls[0] + redirect_params['load-url'] = urls[0] else: - redirect_params["q"] = text + redirect_params['q'] = text if redirect_params: - redirect_url = f"/?{urlencode(redirect_params)}" + redirect_url = f'/?{urlencode(redirect_params)}' return RedirectResponse(url=redirect_url) # Proceed with the normal flow of other requests @@ -1412,25 +1373,23 @@ class APIKeyRestrictionMiddleware: self.app = app async def __call__(self, scope, receive, send): - if scope["type"] == "http": + if scope['type'] == 'http': request = Request(scope) - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') token = None if auth_header: - parts = auth_header.split(" ", 1) + parts = auth_header.split(' ', 1) if len(parts) == 2: token = parts[1] # Only apply restrictions if an sk- API key is used - if token and token.startswith("sk-"): + if token and token.startswith('sk-'): # Check if restrictions are enabled if app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: allowed_paths = [ path.strip() - for path in str( - app.state.config.API_KEYS_ALLOWED_ENDPOINTS - ).split(",") + for path in str(app.state.config.API_KEYS_ALLOWED_ENDPOINTS).split(',') if path.strip() ] @@ -1438,17 +1397,13 @@ class APIKeyRestrictionMiddleware: # Match exact path or prefix path is_allowed = any( - request_path == allowed - or request_path.startswith(allowed + "/") - for allowed in allowed_paths + request_path == allowed or request_path.startswith(allowed + '/') for allowed in allowed_paths ) if not is_allowed: await JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content={ - "detail": "API key not allowed to access this endpoint." - }, + content={'detail': 'API key not allowed to access this endpoint.'}, )(scope, receive, send) return @@ -1458,7 +1413,7 @@ class APIKeyRestrictionMiddleware: app.add_middleware(APIKeyRestrictionMiddleware) -@app.middleware("http") +@app.middleware('http') async def commit_session_after_request(request: Request, call_next): response = await call_next(request) # log.debug("Commit session after request") @@ -1472,51 +1427,44 @@ async def commit_session_after_request(request: Request, call_next): return response -@app.middleware("http") +@app.middleware('http') async def check_url(request: Request, call_next): start_time = int(time.time()) - request.state.token = get_http_authorization_cred( - request.headers.get("Authorization") - ) + request.state.token = get_http_authorization_cred(request.headers.get('Authorization')) # Fallback to cookie token for browser sessions - if request.state.token is None and request.cookies.get("token"): + if request.state.token is None and request.cookies.get('token'): from fastapi.security import HTTPAuthorizationCredentials - request.state.token = HTTPAuthorizationCredentials( - scheme="Bearer", credentials=request.cookies.get("token") - ) + request.state.token = HTTPAuthorizationCredentials(scheme='Bearer', credentials=request.cookies.get('token')) # Fallback to x-api-key header for Anthropic Messages API routes - if request.state.token is None and request.headers.get("x-api-key"): + if request.state.token is None and request.headers.get('x-api-key'): request_path = request.url.path - if request_path in ("/api/message", "/api/v1/messages"): + if request_path in ('/api/message', '/api/v1/messages'): from fastapi.security import HTTPAuthorizationCredentials request.state.token = HTTPAuthorizationCredentials( - scheme="Bearer", credentials=request.headers.get("x-api-key") + scheme='Bearer', credentials=request.headers.get('x-api-key') ) request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS response = await call_next(request) process_time = int(time.time()) - start_time - response.headers["X-Process-Time"] = str(process_time) + response.headers['X-Process-Time'] = str(process_time) return response -@app.middleware("http") +@app.middleware('http') async def inspect_websocket(request: Request, call_next): - if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" - ): - upgrade = (request.headers.get("Upgrade") or "").lower() - connection = (request.headers.get("Connection") or "").lower().split(",") + if '/ws/socket.io' in request.url.path and request.query_params.get('transport') == 'websocket': + upgrade = (request.headers.get('Upgrade') or '').lower() + connection = (request.headers.get('Connection') or '').lower().split(',') # Check that there's the correct headers for an upgrade, else reject the connection # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 - if upgrade != "websocket" or "upgrade" not in connection: + if upgrade != 'websocket' or 'upgrade' not in connection: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "Invalid WebSocket upgrade request"}, + content={'detail': 'Invalid WebSocket upgrade request'}, ) return await call_next(request) @@ -1525,64 +1473,62 @@ app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=['*'], + allow_headers=['*'], ) -app.mount("/ws", socket_app) +app.mount('/ws', socket_app) -app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) -app.include_router(openai.router, prefix="/openai", tags=["openai"]) +app.include_router(ollama.router, prefix='/ollama', tags=['ollama']) +app.include_router(openai.router, prefix='/openai', tags=['openai']) -app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) -app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) -app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) +app.include_router(pipelines.router, prefix='/api/v1/pipelines', tags=['pipelines']) +app.include_router(tasks.router, prefix='/api/v1/tasks', tags=['tasks']) +app.include_router(images.router, prefix='/api/v1/images', tags=['images']) -app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) -app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) +app.include_router(audio.router, prefix='/api/v1/audio', tags=['audio']) +app.include_router(retrieval.router, prefix='/api/v1/retrieval', tags=['retrieval']) -app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) +app.include_router(configs.router, prefix='/api/v1/configs', tags=['configs']) -app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) -app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) +app.include_router(auths.router, prefix='/api/v1/auths', tags=['auths']) +app.include_router(users.router, prefix='/api/v1/users', tags=['users']) -app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"]) -app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) -app.include_router(notes.router, prefix="/api/v1/notes", tags=["notes"]) +app.include_router(channels.router, prefix='/api/v1/channels', tags=['channels']) +app.include_router(chats.router, prefix='/api/v1/chats', tags=['chats']) +app.include_router(notes.router, prefix='/api/v1/notes', tags=['notes']) -app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) -app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) -app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) -app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) -app.include_router(skills.router, prefix="/api/v1/skills", tags=["skills"]) +app.include_router(models.router, prefix='/api/v1/models', tags=['models']) +app.include_router(knowledge.router, prefix='/api/v1/knowledge', tags=['knowledge']) +app.include_router(prompts.router, prefix='/api/v1/prompts', tags=['prompts']) +app.include_router(tools.router, prefix='/api/v1/tools', tags=['tools']) +app.include_router(skills.router, prefix='/api/v1/skills', tags=['skills']) -app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) -app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) -app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"]) -app.include_router(files.router, prefix="/api/v1/files", tags=["files"]) -app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"]) -app.include_router( - evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] -) +app.include_router(memories.router, prefix='/api/v1/memories', tags=['memories']) +app.include_router(folders.router, prefix='/api/v1/folders', tags=['folders']) +app.include_router(groups.router, prefix='/api/v1/groups', tags=['groups']) +app.include_router(files.router, prefix='/api/v1/files', tags=['files']) +app.include_router(functions.router, prefix='/api/v1/functions', tags=['functions']) +app.include_router(evaluations.router, prefix='/api/v1/evaluations', tags=['evaluations']) if ENABLE_ADMIN_ANALYTICS: - app.include_router(analytics.router, prefix="/api/v1/analytics", tags=["analytics"]) -app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) -app.include_router(terminals.router, prefix="/api/v1/terminals", tags=["terminals"]) + app.include_router(analytics.router, prefix='/api/v1/analytics', tags=['analytics']) +app.include_router(utils.router, prefix='/api/v1/utils', tags=['utils']) +app.include_router(terminals.router, prefix='/api/v1/terminals', tags=['terminals']) # SCIM 2.0 API for identity management if ENABLE_SCIM: - app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"]) + app.include_router(scim.router, prefix='/api/v1/scim/v2', tags=['scim']) try: audit_level = AuditLevel(AUDIT_LOG_LEVEL) except ValueError as e: - logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}") + logger.error(f'Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}') audit_level = AuditLevel.NONE if audit_level != AuditLevel.NONE: @@ -1600,35 +1546,30 @@ if audit_level != AuditLevel.NONE: ################################## -@app.get("/api/models") -@app.get("/api/v1/models") # Experimental: Compatibility with OpenAI API -async def get_models( - request: Request, refresh: bool = False, user=Depends(get_verified_user) -): +@app.get('/api/models') +@app.get('/api/v1/models') # Experimental: Compatibility with OpenAI API +async def get_models(request: Request, refresh: bool = False, user=Depends(get_verified_user)): all_models = await get_all_models(request, refresh=refresh, user=user) models = [] for model in all_models: # Filter out filter pipelines - if "pipeline" in model and model["pipeline"].get("type", None) == "filter": + if 'pipeline' in model and model['pipeline'].get('type', None) == 'filter': continue # Remove profile image URL to reduce payload size - if model.get("info", {}).get("meta", {}).get("profile_image_url"): - model["info"]["meta"].pop("profile_image_url", None) + if model.get('info', {}).get('meta', {}).get('profile_image_url'): + model['info']['meta'].pop('profile_image_url', None) try: - model_tags = [ - tag.get("name") - for tag in model.get("info", {}).get("meta", {}).get("tags", []) - ] - tags = [tag.get("name") for tag in model.get("tags", [])] + model_tags = [tag.get('name') for tag in model.get('info', {}).get('meta', {}).get('tags', [])] + tags = [tag.get('name') for tag in model.get('tags', [])] tags = list(set(model_tags + tags)) - model["tags"] = [{"name": tag} for tag in tags] + model['tags'] = [{'name': tag} for tag in tags] except Exception as e: - log.debug(f"Error processing model tags: {e}") - model["tags"] = [] + log.debug(f'Error processing model tags: {e}') + model['tags'] = [] pass models.append(model) @@ -1639,23 +1580,23 @@ async def get_models( # Sort models by order list priority, with fallback for those not in the list models.sort( key=lambda model: ( - model_order_dict.get(model.get("id", ""), float("inf")), - (model.get("name", "") or ""), + model_order_dict.get(model.get('id', ''), float('inf')), + (model.get('name', '') or ''), ) ) models = get_filtered_models(models, user) log.debug( - f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}" + f'/api/models returned filtered models accessible to the user: {json.dumps([model.get("id") for model in models])}' ) - return {"data": models} + return {'data': models} -@app.get("/api/models/base") +@app.get('/api/models/base') async def get_base_models(request: Request, user=Depends(get_admin_user)): models = await get_all_base_models(request, user=user) - return {"data": models} + return {'data': models} ################################## @@ -1663,11 +1604,9 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)): ################################## -@app.post("/api/embeddings") -@app.post("/api/v1/embeddings") # Experimental: Compatibility with OpenAI API -async def embeddings( - request: Request, form_data: dict, user=Depends(get_verified_user) -): +@app.post('/api/embeddings') +@app.post('/api/v1/embeddings') # Experimental: Compatibility with OpenAI API +async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)): """ OpenAI-compatible embeddings endpoint. @@ -1690,8 +1629,8 @@ async def embeddings( return await generate_embeddings(request, form_data, user) -@app.post("/api/chat/completions") -@app.post("/api/v1/chat/completions") # Experimental: Compatibility with OpenAI API +@app.post('/api/chat/completions') +@app.post('/api/v1/chat/completions') # Experimental: Compatibility with OpenAI API async def chat_completion( request: Request, form_data: dict, @@ -1700,24 +1639,22 @@ async def chat_completion( if not request.app.state.MODELS: await get_all_models(request, user=user) - model_id = form_data.get("model", None) - model_item = form_data.pop("model_item", {}) - tasks = form_data.pop("background_tasks", None) + model_id = form_data.get('model', None) + model_item = form_data.pop('model_item', {}) + tasks = form_data.pop('background_tasks', None) metadata = {} try: model_info = None - if not model_item.get("direct", False): + if not model_item.get('direct', False): if model_id not in request.app.state.MODELS: - raise Exception("Model not found") + raise Exception('Model not found') model = request.app.state.MODELS[model_id] model_info = Models.get_model_by_id(model_id) # Check if user has access to the model - if not BYPASS_MODEL_ACCESS_CONTROL and ( - user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL - ): + if not BYPASS_MODEL_ACCESS_CONTROL and (user.role != 'admin' or not BYPASS_ADMIN_ACCESS_CONTROL): try: check_model_access(user, model) except Exception as e: @@ -1729,16 +1666,10 @@ async def chat_completion( request.state.model = model # Model params: global defaults as base, per-model overrides win - default_model_params = ( - getattr(request.app.state.config, "DEFAULT_MODEL_PARAMS", None) or {} - ) + default_model_params = getattr(request.app.state.config, 'DEFAULT_MODEL_PARAMS', None) or {} model_info_params = { **default_model_params, - **( - model_info.params.model_dump() - if model_info and model_info.params - else {} - ), + **(model_info.params.model_dump() if model_info and model_info.params else {}), } # Check base model existence for custom models @@ -1746,81 +1677,68 @@ async def chat_completion( base_model_id = model_info.base_model_id if base_model_id not in request.app.state.MODELS: if ENABLE_CUSTOM_MODEL_FALLBACK: - default_models = ( - request.app.state.config.DEFAULT_MODELS or "" - ).split(",") + default_models = (request.app.state.config.DEFAULT_MODELS or '').split(',') - fallback_model_id = ( - default_models[0].strip() if default_models[0] else None - ) + fallback_model_id = default_models[0].strip() if default_models[0] else None - if ( - fallback_model_id - and fallback_model_id in request.app.state.MODELS - ): + if fallback_model_id and fallback_model_id in request.app.state.MODELS: # Update model and form_data so routing uses the fallback model's type model = request.app.state.MODELS[fallback_model_id] - form_data["model"] = fallback_model_id + form_data['model'] = fallback_model_id else: - raise Exception("Model not found") + raise Exception('Model not found') else: - raise Exception("Model not found") + raise Exception('Model not found') # Chat Params - stream_delta_chunk_size = form_data.get("params", {}).get( - "stream_delta_chunk_size" - ) - reasoning_tags = form_data.get("params", {}).get("reasoning_tags") + stream_delta_chunk_size = form_data.get('params', {}).get('stream_delta_chunk_size') + reasoning_tags = form_data.get('params', {}).get('reasoning_tags') # Model Params - if model_info_params.get("stream_response") is not None: - form_data["stream"] = model_info_params.get("stream_response") + if model_info_params.get('stream_response') is not None: + form_data['stream'] = model_info_params.get('stream_response') - if model_info_params.get("stream_delta_chunk_size"): - stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size") + if model_info_params.get('stream_delta_chunk_size'): + stream_delta_chunk_size = model_info_params.get('stream_delta_chunk_size') - if model_info_params.get("reasoning_tags") is not None: - reasoning_tags = model_info_params.get("reasoning_tags") + if model_info_params.get('reasoning_tags') is not None: + reasoning_tags = model_info_params.get('reasoning_tags') metadata = { - "user_id": user.id, - "chat_id": form_data.pop("chat_id", None), - "message_id": form_data.pop("id", None), - "parent_message": form_data.pop("parent_message", None), - "parent_message_id": form_data.pop("parent_id", None), - "session_id": form_data.pop("session_id", None), - "filter_ids": form_data.pop("filter_ids", []), - "tool_ids": form_data.get("tool_ids", None), - "tool_servers": form_data.pop("tool_servers", None), - "files": form_data.get("files", None), - "features": form_data.get("features", {}), - "variables": form_data.get("variables", {}), - "model": model, - "direct": model_item.get("direct", False), - "params": { - "stream_delta_chunk_size": stream_delta_chunk_size, - "reasoning_tags": reasoning_tags, - "function_calling": ( - "native" + 'user_id': user.id, + 'chat_id': form_data.pop('chat_id', None), + 'message_id': form_data.pop('id', None), + 'parent_message': form_data.pop('parent_message', None), + 'parent_message_id': form_data.pop('parent_id', None), + 'session_id': form_data.pop('session_id', None), + 'filter_ids': form_data.pop('filter_ids', []), + 'tool_ids': form_data.get('tool_ids', None), + 'tool_servers': form_data.pop('tool_servers', None), + 'files': form_data.get('files', None), + 'features': form_data.get('features', {}), + 'variables': form_data.get('variables', {}), + 'model': model, + 'direct': model_item.get('direct', False), + 'params': { + 'stream_delta_chunk_size': stream_delta_chunk_size, + 'reasoning_tags': reasoning_tags, + 'function_calling': ( + 'native' if ( - form_data.get("params", {}).get("function_calling") == "native" - or model_info_params.get("function_calling") == "native" + form_data.get('params', {}).get('function_calling') == 'native' + or model_info_params.get('function_calling') == 'native' ) - else "default" + else 'default' ), }, } - if metadata.get("chat_id") and user: - if not metadata["chat_id"].startswith( - "local:" - ): # temporary chats are not stored - + if metadata.get('chat_id') and user: + if not metadata['chat_id'].startswith('local:'): # temporary chats are not stored # Verify chat ownership — lightweight EXISTS check avoids # deserializing the full chat JSON blob just to confirm the row exists if ( - not Chats.is_chat_owner(metadata["chat_id"], user.id) - and user.role != "admin" + not Chats.is_chat_owner(metadata['chat_id'], user.id) and user.role != 'admin' ): # admins can access any chat raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -1828,29 +1746,29 @@ async def chat_completion( ) # Insert chat files from parent message if any - parent_message = metadata.get("parent_message") or {} - parent_message_files = parent_message.get("files", []) + parent_message = metadata.get('parent_message') or {} + parent_message_files = parent_message.get('files', []) if parent_message_files: try: Chats.insert_chat_files( - metadata["chat_id"], - parent_message.get("id"), + metadata['chat_id'], + parent_message.get('id'), [ - file_item.get("id") + file_item.get('id') for file_item in parent_message_files - if file_item.get("type") == "file" + if file_item.get('type') == 'file' ], user.id, ) except Exception as e: - log.debug(f"Error inserting chat files: {e}") + log.debug(f'Error inserting chat files: {e}') pass request.state.metadata = metadata - form_data["metadata"] = metadata + form_data['metadata'] = metadata except Exception as e: - log.debug(f"Error processing chat metadata: {e}") + log.debug(f'Error processing chat metadata: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e), @@ -1858,37 +1776,33 @@ async def chat_completion( async def process_chat(request, form_data, user, metadata, model): try: - form_data, metadata, events = await process_chat_payload( - request, form_data, user, metadata, model - ) + form_data, metadata, events = await process_chat_payload(request, form_data, user, metadata, model) response = await chat_completion_handler(request, form_data, user) - if metadata.get("chat_id") and metadata.get("message_id"): + if metadata.get('chat_id') and metadata.get('message_id'): try: - if not metadata["chat_id"].startswith("local:"): + if not metadata['chat_id'].startswith('local:'): Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "parentId": metadata.get("parent_message_id", None), - "model": model_id, + 'parentId': metadata.get('parent_message_id', None), + 'model': model_id, }, ) except Exception: pass - ctx = build_chat_response_context( - request, form_data, user, model, metadata, tasks, events - ) + ctx = build_chat_response_context(request, form_data, user, model, metadata, tasks, events) return await process_chat_response(response, ctx) except asyncio.CancelledError: - log.info("Chat processing was cancelled") + log.info('Chat processing was cancelled') try: event_emitter = get_event_emitter(metadata) await asyncio.shield( event_emitter( - {"type": "chat:tasks:cancel"}, + {'type': 'chat:tasks:cancel'}, ) ) except Exception as e: @@ -1896,68 +1810,62 @@ async def chat_completion( finally: raise # re-raise to ensure proper task cancellation handling except Exception as e: - log.debug(f"Error processing chat payload: {e}") - if metadata.get("chat_id") and metadata.get("message_id"): + log.debug(f'Error processing chat payload: {e}') + if metadata.get('chat_id') and metadata.get('message_id'): # Update the chat message with the error try: - if not metadata["chat_id"].startswith("local:"): + if not metadata['chat_id'].startswith('local:'): Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "parentId": metadata.get("parent_message_id", None), - "error": {"content": str(e)}, + 'parentId': metadata.get('parent_message_id', None), + 'error': {'content': str(e)}, }, ) event_emitter = get_event_emitter(metadata) await event_emitter( { - "type": "chat:message:error", - "data": {"error": {"content": str(e)}}, + 'type': 'chat:message:error', + 'data': {'error': {'content': str(e)}}, } ) await event_emitter( - {"type": "chat:tasks:cancel"}, + {'type': 'chat:tasks:cancel'}, ) except Exception: pass finally: try: - if mcp_clients := metadata.get("mcp_clients"): + if mcp_clients := metadata.get('mcp_clients'): for client in reversed(mcp_clients.values()): await client.disconnect() except Exception as e: - log.debug(f"Error cleaning up: {e}") + log.debug(f'Error cleaning up: {e}') pass # Emit chat:active=false when task completes try: - if metadata.get("chat_id"): + if metadata.get('chat_id'): event_emitter = get_event_emitter(metadata, update_db=False) if event_emitter: - await event_emitter( - {"type": "chat:active", "data": {"active": False}} - ) + await event_emitter({'type': 'chat:active', 'data': {'active': False}}) except Exception as e: - log.debug(f"Error emitting chat:active: {e}") + log.debug(f'Error emitting chat:active: {e}') - if ( - metadata.get("session_id") - and metadata.get("chat_id") - and metadata.get("message_id") - ): + if metadata.get('session_id') and metadata.get('chat_id') and metadata.get('message_id'): # Asynchronous Chat Processing task_id, _ = await create_task( request.app.state.redis, process_chat(request, form_data, user, metadata, model), - id=metadata["chat_id"], + id=metadata['chat_id'], ) # Emit chat:active=true when task starts event_emitter = get_event_emitter(metadata, update_db=False) if event_emitter: - await event_emitter({"type": "chat:active", "data": {"active": True}}) - return {"status": True, "task_id": task_id} + await event_emitter({'type': 'chat:active', 'data': {'active': True}}) + return {'status': True, 'task_id': task_id} else: return await process_chat(request, form_data, user, metadata, model) @@ -1981,8 +1889,8 @@ from open_webui.utils.anthropic import ( ) -@app.post("/api/message") -@app.post("/api/v1/messages") # Anthropic Messages API compatible endpoint +@app.post('/api/message') +@app.post('/api/v1/messages') # Anthropic Messages API compatible endpoint async def generate_messages( request: Request, form_data: dict, @@ -2002,7 +1910,7 @@ async def generate_messages( Anthropic's x-api-key header (via middleware translation). """ # Convert Anthropic payload to OpenAI format - requested_model = form_data.get("model", "") + requested_model = form_data.get('model', '') openai_payload = convert_anthropic_to_openai_payload(form_data) @@ -2013,13 +1921,11 @@ async def generate_messages( if isinstance(response, StreamingResponse): # Streaming response: wrap the generator to convert SSE format return StreamingResponse( - openai_stream_to_anthropic_stream( - response.body_iterator, model=requested_model - ), - media_type="text/event-stream", + openai_stream_to_anthropic_stream(response.body_iterator, model=requested_model), + media_type='text/event-stream', headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', }, ) elif isinstance(response, dict): @@ -2029,14 +1935,12 @@ async def generate_messages( return response -@app.post("/api/chat/completed") -async def chat_completed( - request: Request, form_data: dict, user=Depends(get_verified_user) -): +@app.post('/api/chat/completed') +async def chat_completed(request: Request, form_data: dict, user=Depends(get_verified_user)): try: - model_item = form_data.pop("model_item", {}) + model_item = form_data.pop('model_item', {}) - if model_item.get("direct", False): + if model_item.get('direct', False): request.state.direct = True request.state.model = model_item @@ -2048,14 +1952,12 @@ async def chat_completed( ) -@app.post("/api/chat/actions/{action_id}") -async def chat_action( - request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) -): +@app.post('/api/chat/actions/{action_id}') +async def chat_action(request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)): try: - model_item = form_data.pop("model_item", {}) + model_item = form_data.pop('model_item', {}) - if model_item.get("direct", False): + if model_item.get('direct', False): request.state.direct = True request.state.model = model_item @@ -2067,10 +1969,8 @@ async def chat_action( ) -@app.post("/api/tasks/stop/{task_id}") -async def stop_task_endpoint( - request: Request, task_id: str, user=Depends(get_verified_user) -): +@app.post('/api/tasks/stop/{task_id}') +async def stop_task_endpoint(request: Request, task_id: str, user=Depends(get_verified_user)): try: result = await stop_task(request.app.state.redis, task_id) return result @@ -2078,23 +1978,21 @@ async def stop_task_endpoint( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) -@app.get("/api/tasks") +@app.get('/api/tasks') async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): - return {"tasks": await list_tasks(request.app.state.redis)} + return {'tasks': await list_tasks(request.app.state.redis)} -@app.get("/api/tasks/chat/{chat_id}") -async def list_tasks_by_chat_id_endpoint( - request: Request, chat_id: str, user=Depends(get_verified_user) -): +@app.get('/api/tasks/chat/{chat_id}') +async def list_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id(chat_id) if chat is None or chat.user_id != user.id: - return {"task_ids": []} + return {'task_ids': []} task_ids = await list_task_ids_by_item_id(request.app.state.redis, chat_id) - log.debug(f"Task IDs for chat {chat_id}: {task_ids}") - return {"task_ids": task_ids} + log.debug(f'Task IDs for chat {chat_id}: {task_ids}') + return {'task_ids': task_ids} ################################## @@ -2104,19 +2002,19 @@ async def list_tasks_by_chat_id_endpoint( ################################## -@app.get("/api/config") +@app.get('/api/config') async def get_app_config(request: Request): user = None token = None - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') if auth_header: cred = get_http_authorization_cred(auth_header) if cred: token = cred.credentials - if not token and "token" in request.cookies: - token = request.cookies.get("token") + if not token and 'token' in request.cookies: + token = request.cookies.get('token') if token: try: @@ -2125,10 +2023,10 @@ async def get_app_config(request: Request): log.debug(e) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", + detail='Invalid token', ) - if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + if data is not None and 'id' in data: + user = Users.get_user_by_id(data['id']) user_count = Users.get_num_users() onboarding = False @@ -2137,55 +2035,50 @@ async def get_app_config(request: Request): onboarding = user_count == 0 return { - **({"onboarding": True} if onboarding else {}), - "status": True, - "name": app.state.WEBUI_NAME, - "version": VERSION, - "default_locale": str(DEFAULT_LOCALE), - "oauth": { - "providers": { - name: config.get("name", name) - for name, config in OAUTH_PROVIDERS.items() - } - }, - "features": { - "auth": WEBUI_AUTH, - "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION, - "enable_ldap": app.state.config.ENABLE_LDAP, - "enable_api_keys": app.state.config.ENABLE_API_KEYS, - "enable_signup": app.state.config.ENABLE_SIGNUP, - "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, - "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, - "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK, - "enable_public_active_users_count": ENABLE_PUBLIC_ACTIVE_USERS_COUNT, - "enable_easter_eggs": ENABLE_EASTER_EGGS, + **({'onboarding': True} if onboarding else {}), + 'status': True, + 'name': app.state.WEBUI_NAME, + 'version': VERSION, + 'default_locale': str(DEFAULT_LOCALE), + 'oauth': {'providers': {name: config.get('name', name) for name, config in OAUTH_PROVIDERS.items()}}, + 'features': { + 'auth': WEBUI_AUTH, + 'auth_trusted_header': bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + 'enable_signup_password_confirmation': ENABLE_SIGNUP_PASSWORD_CONFIRMATION, + 'enable_ldap': app.state.config.ENABLE_LDAP, + 'enable_api_keys': app.state.config.ENABLE_API_KEYS, + 'enable_signup': app.state.config.ENABLE_SIGNUP, + 'enable_login_form': app.state.config.ENABLE_LOGIN_FORM, + 'enable_websocket': ENABLE_WEBSOCKET_SUPPORT, + 'enable_version_update_check': ENABLE_VERSION_UPDATE_CHECK, + 'enable_public_active_users_count': ENABLE_PUBLIC_ACTIVE_USERS_COUNT, + 'enable_easter_eggs': ENABLE_EASTER_EGGS, **( { - "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, - "enable_folders": app.state.config.ENABLE_FOLDERS, - "folder_max_file_count": app.state.config.FOLDER_MAX_FILE_COUNT, - "enable_channels": app.state.config.ENABLE_CHANNELS, - "enable_notes": app.state.config.ENABLE_NOTES, - "enable_web_search": app.state.config.ENABLE_WEB_SEARCH, - "enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION, - "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER, - "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, - "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, - "enable_user_webhooks": app.state.config.ENABLE_USER_WEBHOOKS, - "enable_user_status": app.state.config.ENABLE_USER_STATUS, - "enable_admin_export": ENABLE_ADMIN_EXPORT, - "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, - "enable_admin_analytics": ENABLE_ADMIN_ANALYTICS, - "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, - "enable_memories": app.state.config.ENABLE_MEMORIES, + 'enable_direct_connections': app.state.config.ENABLE_DIRECT_CONNECTIONS, + 'enable_folders': app.state.config.ENABLE_FOLDERS, + 'folder_max_file_count': app.state.config.FOLDER_MAX_FILE_COUNT, + 'enable_channels': app.state.config.ENABLE_CHANNELS, + 'enable_notes': app.state.config.ENABLE_NOTES, + 'enable_web_search': app.state.config.ENABLE_WEB_SEARCH, + 'enable_code_execution': app.state.config.ENABLE_CODE_EXECUTION, + 'enable_code_interpreter': app.state.config.ENABLE_CODE_INTERPRETER, + 'enable_image_generation': app.state.config.ENABLE_IMAGE_GENERATION, + 'enable_autocomplete_generation': app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + 'enable_community_sharing': app.state.config.ENABLE_COMMUNITY_SHARING, + 'enable_message_rating': app.state.config.ENABLE_MESSAGE_RATING, + 'enable_user_webhooks': app.state.config.ENABLE_USER_WEBHOOKS, + 'enable_user_status': app.state.config.ENABLE_USER_STATUS, + 'enable_admin_export': ENABLE_ADMIN_EXPORT, + 'enable_admin_chat_access': ENABLE_ADMIN_CHAT_ACCESS, + 'enable_admin_analytics': ENABLE_ADMIN_ANALYTICS, + 'enable_google_drive_integration': app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + 'enable_onedrive_integration': app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + 'enable_memories': app.state.config.ENABLE_MEMORIES, **( { - "enable_onedrive_personal": ENABLE_ONEDRIVE_PERSONAL, - "enable_onedrive_business": ENABLE_ONEDRIVE_BUSINESS, + 'enable_onedrive_personal': ENABLE_ONEDRIVE_PERSONAL, + 'enable_onedrive_business': ENABLE_ONEDRIVE_BUSINESS, } if app.state.config.ENABLE_ONEDRIVE_INTEGRATION else {} @@ -2197,78 +2090,74 @@ async def get_app_config(request: Request): }, **( { - "default_models": app.state.config.DEFAULT_MODELS, - "default_pinned_models": app.state.config.DEFAULT_PINNED_MODELS, - "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "user_count": user_count, - "code": { - "engine": app.state.config.CODE_EXECUTION_ENGINE, - "interpreter_engine": app.state.config.CODE_INTERPRETER_ENGINE, + 'default_models': app.state.config.DEFAULT_MODELS, + 'default_pinned_models': app.state.config.DEFAULT_PINNED_MODELS, + 'default_prompt_suggestions': app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + 'user_count': user_count, + 'code': { + 'engine': app.state.config.CODE_EXECUTION_ENGINE, + 'interpreter_engine': app.state.config.CODE_INTERPRETER_ENGINE, }, - "audio": { - "tts": { - "engine": app.state.config.TTS_ENGINE, - "voice": app.state.config.TTS_VOICE, - "split_on": app.state.config.TTS_SPLIT_ON, + 'audio': { + 'tts': { + 'engine': app.state.config.TTS_ENGINE, + 'voice': app.state.config.TTS_VOICE, + 'split_on': app.state.config.TTS_SPLIT_ON, }, - "stt": { - "engine": app.state.config.STT_ENGINE, + 'stt': { + 'engine': app.state.config.STT_ENGINE, }, }, - "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, - "image_compression": { - "width": app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "height": app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + 'file': { + 'max_size': app.state.config.FILE_MAX_SIZE, + 'max_count': app.state.config.FILE_MAX_COUNT, + 'image_compression': { + 'width': app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + 'height': app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, }, }, - "permissions": {**app.state.config.USER_PERMISSIONS}, - "google_drive": { - "client_id": GOOGLE_DRIVE_CLIENT_ID.value, - "api_key": GOOGLE_DRIVE_API_KEY.value, + 'permissions': {**app.state.config.USER_PERMISSIONS}, + 'google_drive': { + 'client_id': GOOGLE_DRIVE_CLIENT_ID.value, + 'api_key': GOOGLE_DRIVE_API_KEY.value, }, - "onedrive": { - "client_id_personal": ONEDRIVE_CLIENT_ID_PERSONAL, - "client_id_business": ONEDRIVE_CLIENT_ID_BUSINESS, - "sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value, - "sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value, + 'onedrive': { + 'client_id_personal': ONEDRIVE_CLIENT_ID_PERSONAL, + 'client_id_business': ONEDRIVE_CLIENT_ID_BUSINESS, + 'sharepoint_url': ONEDRIVE_SHAREPOINT_URL.value, + 'sharepoint_tenant_id': ONEDRIVE_SHAREPOINT_TENANT_ID.value, }, - "ui": { - "pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE, - "pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT, - "response_watermark": app.state.config.RESPONSE_WATERMARK, + 'ui': { + 'pending_user_overlay_title': app.state.config.PENDING_USER_OVERLAY_TITLE, + 'pending_user_overlay_content': app.state.config.PENDING_USER_OVERLAY_CONTENT, + 'response_watermark': app.state.config.RESPONSE_WATERMARK, }, - "license_metadata": app.state.LICENSE_METADATA, + 'license_metadata': app.state.LICENSE_METADATA, **( { - "active_entries": app.state.USER_COUNT, + 'active_entries': app.state.USER_COUNT, } - if user.role == "admin" + if user.role == 'admin' else {} ), } - if user is not None and (user.role in ["admin", "user"]) + if user is not None and (user.role in ['admin', 'user']) else { **( { - "ui": { - "pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE, - "pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT, + 'ui': { + 'pending_user_overlay_title': app.state.config.PENDING_USER_OVERLAY_TITLE, + 'pending_user_overlay_content': app.state.config.PENDING_USER_OVERLAY_CONTENT, } } - if user and user.role == "pending" + if user and user.role == 'pending' else {} ), **( { - "metadata": { - "login_footer": app.state.LICENSE_METADATA.get( - "login_footer", "" - ), - "auth_logo_position": app.state.LICENSE_METADATA.get( - "auth_logo_position", "" - ), + 'metadata': { + 'login_footer': app.state.LICENSE_METADATA.get('login_footer', ''), + 'auth_logo_position': app.state.LICENSE_METADATA.get('auth_logo_position', ''), } } if app.state.LICENSE_METADATA @@ -2283,58 +2172,56 @@ class UrlForm(BaseModel): url: str -@app.get("/api/webhook") +@app.get('/api/webhook') async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": app.state.config.WEBHOOK_URL, + 'url': app.state.config.WEBHOOK_URL, } -@app.post("/api/webhook") +@app.post('/api/webhook') async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL - return {"url": app.state.config.WEBHOOK_URL} + return {'url': app.state.config.WEBHOOK_URL} -@app.get("/api/version") +@app.get('/api/version') async def get_app_version(): return { - "version": VERSION, - "deployment_id": DEPLOYMENT_ID, + 'version': VERSION, + 'deployment_id': DEPLOYMENT_ID, } -@app.get("/api/version/updates") +@app.get('/api/version/updates') async def get_app_latest_release_version(user=Depends(get_verified_user)): if not ENABLE_VERSION_UPDATE_CHECK: - log.debug( - f"Version update check is disabled, returning current version as latest version" - ) - return {"current": VERSION, "latest": VERSION} + log.debug(f'Version update check is disabled, returning current version as latest version') + return {'current': VERSION, 'latest': VERSION} try: timeout = aiohttp.ClientTimeout(total=1) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest", + 'https://api.github.com/repos/open-webui/open-webui/releases/latest', ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() data = await response.json() - latest_version = data["tag_name"] + latest_version = data['tag_name'] - return {"current": VERSION, "latest": latest_version[1:]} + return {'current': VERSION, 'latest': latest_version[1:]} except Exception as e: log.debug(e) - return {"current": VERSION, "latest": VERSION} + return {'current': VERSION, 'latest': VERSION} -@app.get("/api/changelog") +@app.get('/api/changelog') async def get_app_changelog(): return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} -@app.get("/api/usage") +@app.get('/api/usage') async def get_current_usage(user=Depends(get_verified_user)): """ Get current usage statistics for Open WebUI. @@ -2342,21 +2229,21 @@ async def get_current_usage(user=Depends(get_verified_user)): """ try: # If public visibility is disabled, only allow admins to access this endpoint - if not ENABLE_PUBLIC_ACTIVE_USERS_COUNT and user.role != "admin": + if not ENABLE_PUBLIC_ACTIVE_USERS_COUNT and user.role != 'admin': raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access denied. Only administrators can view usage statistics.", + detail='Access denied. Only administrators can view usage statistics.', ) return { - "model_ids": get_models_in_use(), - "user_count": Users.get_active_user_count(), + 'model_ids': get_models_in_use(), + 'user_count': Users.get_active_user_count(), } except HTTPException: raise except Exception as e: - log.error(f"Error getting usage statistics: {e}") - raise HTTPException(status_code=500, detail="Internal Server Error") + log.error(f'Error getting usage statistics: {e}') + raise HTTPException(status_code=500, detail='Internal Server Error') ############################ @@ -2367,114 +2254,102 @@ async def get_current_usage(user=Depends(get_verified_user)): # Initialize OAuth client manager with any MCP tool servers using OAuth 2.1 if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0: for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS: - if tool_server_connection.get("type", "openapi") == "mcp": - server_id = tool_server_connection.get("info", {}).get("id") - auth_type = tool_server_connection.get("auth_type", "none") + if tool_server_connection.get('type', 'openapi') == 'mcp': + server_id = tool_server_connection.get('info', {}).get('id') + auth_type = tool_server_connection.get('auth_type', 'none') - if server_id and auth_type == "oauth_2.1": - oauth_client_info = tool_server_connection.get("info", {}).get( - "oauth_client_info", "" - ) + if server_id and auth_type == 'oauth_2.1': + oauth_client_info = tool_server_connection.get('info', {}).get('oauth_client_info', '') try: oauth_client_info = decrypt_data(oauth_client_info) app.state.oauth_client_manager.add_client( - f"mcp:{server_id}", + f'mcp:{server_id}', OAuthClientInformationFull(**oauth_client_info), ) except Exception as e: - log.error( - f"Error adding OAuth client for MCP tool server {server_id}: {e}" - ) + log.error(f'Error adding OAuth client for MCP tool server {server_id}: {e}') pass try: if ENABLE_STAR_SESSIONS_MIDDLEWARE: redis_session_store = RedisStore( url=REDIS_URL, - prefix=(f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"), + prefix=(f'{REDIS_KEY_PREFIX}:session:' if REDIS_KEY_PREFIX else 'session:'), ) app.add_middleware(SessionAutoloadMiddleware) app.add_middleware( StarSessionsMiddleware, store=redis_session_store, - cookie_name="owui-session", + cookie_name='owui-session', cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE, cookie_https_only=WEBUI_SESSION_COOKIE_SECURE, ) - log.info("Using Redis for session") + log.info('Using Redis for session') else: - raise ValueError("No Redis URL provided") + raise ValueError('No Redis URL provided') except Exception as e: app.add_middleware( SessionMiddleware, secret_key=WEBUI_SECRET_KEY, - session_cookie="owui-session", + session_cookie='owui-session', same_site=WEBUI_SESSION_COOKIE_SAME_SITE, https_only=WEBUI_SESSION_COOKIE_SECURE, ) async def register_client(request, client_id: str) -> bool: - server_type, server_id = client_id.split(":", 1) + server_type, server_id = client_id.split(':', 1) connection = None connection_idx = None for idx, conn in enumerate(request.app.state.config.TOOL_SERVER_CONNECTIONS or []): - if conn.get("type", "openapi") == server_type: - info = conn.get("info", {}) - if info.get("id") == server_id: + if conn.get('type', 'openapi') == server_type: + info = conn.get('info', {}) + if info.get('id') == server_id: connection = conn connection_idx = idx break if connection is None or connection_idx is None: - log.warning( - f"Unable to locate MCP tool server configuration for client {client_id} during re-registration" - ) + log.warning(f'Unable to locate MCP tool server configuration for client {client_id} during re-registration') return False - server_url = connection.get("url") - oauth_server_key = (connection.get("config") or {}).get("oauth_server_key") + server_url = connection.get('url') + oauth_server_key = (connection.get('config') or {}).get('oauth_server_key') try: - oauth_client_info = ( - await get_oauth_client_info_with_dynamic_client_registration( - request, - client_id, - server_url, - oauth_server_key, - ) + oauth_client_info = await get_oauth_client_info_with_dynamic_client_registration( + request, + client_id, + server_url, + oauth_server_key, ) except Exception as e: - log.error(f"Dynamic client re-registration failed for {client_id}: {e}") + log.error(f'Dynamic client re-registration failed for {client_id}: {e}') return False try: request.app.state.config.TOOL_SERVER_CONNECTIONS[connection_idx] = { **connection, - "info": { - **connection.get("info", {}), - "oauth_client_info": encrypt_data( - oauth_client_info.model_dump(mode="json") - ), + 'info': { + **connection.get('info', {}), + 'oauth_client_info': encrypt_data(oauth_client_info.model_dump(mode='json')), }, } except Exception as e: - log.error( - f"Failed to persist updated OAuth client info for tool server {client_id}: {e}" - ) + log.error(f'Failed to persist updated OAuth client info for tool server {client_id}: {e}') return False oauth_client_manager.remove_client(client_id) oauth_client_manager.add_client(client_id, oauth_client_info) - log.info(f"Re-registered OAuth client {client_id} for tool server") + log.info(f'Re-registered OAuth client {client_id} for tool server') return True -@app.get("/oauth/clients/{client_id}/authorize") +@app.get('/oauth/clients/{client_id}/authorize') async def oauth_client_authorize( client_id: str, request: Request, @@ -2489,7 +2364,7 @@ async def oauth_client_authorize( if not await oauth_client_manager._preflight_authorization_url(client, client_info): log.info( - "Detected invalid OAuth client %s; attempting re-registration", + 'Detected invalid OAuth client %s; attempting re-registration', client_id, ) @@ -2497,7 +2372,7 @@ async def oauth_client_authorize( if not registered: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to re-register OAuth client", + detail='Failed to re-register OAuth client', ) client = oauth_client_manager.get_client(client_id) @@ -2505,21 +2380,19 @@ async def oauth_client_authorize( if client is None or client_info is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth client unavailable after re-registration", + detail='OAuth client unavailable after re-registration', ) - if not await oauth_client_manager._preflight_authorization_url( - client, client_info - ): + if not await oauth_client_manager._preflight_authorization_url(client, client_info): raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth client registration is still invalid after re-registration", + detail='OAuth client registration is still invalid after re-registration', ) return await oauth_client_manager.handle_authorize(request, client_id=client_id) -@app.get("/oauth/clients/{client_id}/callback") +@app.get('/oauth/clients/{client_id}/callback') async def oauth_client_callback( client_id: str, request: Request, @@ -2534,7 +2407,7 @@ async def oauth_client_callback( ) -@app.get("/oauth/{provider}/login") +@app.get('/oauth/{provider}/login') async def oauth_login(provider: str, request: Request): return await oauth_manager.handle_login(request, provider) @@ -2545,8 +2418,8 @@ async def oauth_login(provider: str, request: Request): # - This is considered insecure in general, as OAuth providers do not always verify email addresses # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user # - Email addresses are considered unique, so we fail registration if the email address is already taken -@app.get("/oauth/{provider}/login/callback") -@app.get("/oauth/{provider}/callback") # Legacy endpoint +@app.get('/oauth/{provider}/login/callback') +@app.get('/oauth/{provider}/callback') # Legacy endpoint async def oauth_login_callback( provider: str, request: Request, @@ -2556,41 +2429,41 @@ async def oauth_login_callback( return await oauth_manager.handle_callback(request, provider, response, db=db) -@app.get("/manifest.json") +@app.get('/manifest.json') async def get_manifest_json(): if app.state.EXTERNAL_PWA_MANIFEST_URL: return requests.get(app.state.EXTERNAL_PWA_MANIFEST_URL).json() else: return { - "name": app.state.WEBUI_NAME, - "short_name": app.state.WEBUI_NAME, - "description": f"{app.state.WEBUI_NAME} is an open, extensible, user-friendly interface for AI that adapts to your workflow.", - "start_url": "/", - "display": "standalone", - "background_color": "#343541", - "icons": [ + 'name': app.state.WEBUI_NAME, + 'short_name': app.state.WEBUI_NAME, + 'description': f'{app.state.WEBUI_NAME} is an open, extensible, user-friendly interface for AI that adapts to your workflow.', + 'start_url': '/', + 'display': 'standalone', + 'background_color': '#343541', + 'icons': [ { - "src": "/static/logo.png", - "type": "image/png", - "sizes": "500x500", - "purpose": "any", + 'src': '/static/logo.png', + 'type': 'image/png', + 'sizes': '500x500', + 'purpose': 'any', }, { - "src": "/static/logo.png", - "type": "image/png", - "sizes": "500x500", - "purpose": "maskable", + 'src': '/static/logo.png', + 'type': 'image/png', + 'sizes': '500x500', + 'purpose': 'maskable', }, ], - "share_target": { - "action": "/", - "method": "GET", - "params": {"text": "shared"}, + 'share_target': { + 'action': '/', + 'method': 'GET', + 'params': {'text': 'shared'}, }, } -@app.get("/opensearch.xml") +@app.get('/opensearch.xml') async def get_opensearch_xml(): xml_content = rf""" @@ -2598,40 +2471,40 @@ async def get_opensearch_xml(): Search {app.state.WEBUI_NAME} UTF-8 {app.state.config.WEBUI_URL}/static/favicon.png - + {app.state.config.WEBUI_URL} """ - return Response(content=xml_content, media_type="application/xml") + return Response(content=xml_content, media_type='application/xml') -@app.get("/health") +@app.get('/health') async def healthcheck(): - return {"status": True} + return {'status': True} -@app.get("/ready") +@app.get('/ready') async def readiness_check(): """ Returns 200 only when the application is ready to accept traffic. """ # Ensure application startup work has completed - if not getattr(app.state, "startup_complete", False): - log.info("Readiness check failed: startup not complete") + if not getattr(app.state, 'startup_complete', False): + log.info('Readiness check failed: startup not complete') raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Startup not complete", + detail='Startup not complete', ) # Check database connectivity try: - ScopedSession.execute(text("SELECT 1;")).all() + ScopedSession.execute(text('SELECT 1;')).all() except Exception as e: - log.warning(f"Readiness check DB ping failed: {e!r}") + log.warning(f'Readiness check DB ping failed: {e!r}') raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database not ready", + detail='Database not ready', ) # Check Redis connectivity if configured @@ -2640,27 +2513,27 @@ async def readiness_check(): try: pong = await redis.ping() if pong is False: - raise Exception("Redis PING returned False") + raise Exception('Redis PING returned False') except Exception as e: - log.warning(f"Readiness check Redis ping failed: {e!r}") + log.warning(f'Readiness check Redis ping failed: {e!r}') raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Redis not ready", + detail='Redis not ready', ) - return {"status": True} + return {'status': True} -@app.get("/health/db") +@app.get('/health/db') async def healthcheck_with_db(): - ScopedSession.execute(text("SELECT 1;")).all() - return {"status": True} + ScopedSession.execute(text('SELECT 1;')).all() + return {'status': True} -app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") +app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static') -@app.get("/cache/{path:path}") +@app.get('/cache/{path:path}') async def serve_cache_file( path: str, user=Depends(get_verified_user), @@ -2668,9 +2541,9 @@ async def serve_cache_file( file_path = os.path.abspath(os.path.join(CACHE_DIR, path)) # prevent path traversal if not file_path.startswith(os.path.abspath(CACHE_DIR)): - raise HTTPException(status_code=404, detail="File not found") + raise HTTPException(status_code=404, detail='File not found') if not os.path.isfile(file_path): - raise HTTPException(status_code=404, detail="File not found") + raise HTTPException(status_code=404, detail='File not found') return FileResponse(file_path) @@ -2678,22 +2551,20 @@ def swagger_ui_html(*args, **kwargs): return get_swagger_ui_html( *args, **kwargs, - swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js", - swagger_css_url="/static/swagger-ui/swagger-ui.css", - swagger_favicon_url="/static/swagger-ui/favicon.png", + swagger_js_url='/static/swagger-ui/swagger-ui-bundle.js', + swagger_css_url='/static/swagger-ui/swagger-ui.css', + swagger_favicon_url='/static/swagger-ui/favicon.png', ) applications.get_swagger_ui_html = swagger_ui_html if os.path.exists(FRONTEND_BUILD_DIR): - mimetypes.add_type("text/javascript", ".js") + mimetypes.add_type('text/javascript', '.js') app.mount( - "/", + '/', SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), - name="spa-static-files", + name='spa-static-files', ) else: - log.warning( - f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only." - ) + log.warning(f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only.") diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 720b90f5fc..9ee6c2dceb 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -16,7 +16,7 @@ if config.config_file_name is not None: fileConfig(config.config_file_name, disable_existing_loggers=False) # Re-apply JSON formatter after fileConfig replaces handlers. -if LOG_FORMAT == "json": +if LOG_FORMAT == 'json': from open_webui.env import JSONFormatter for handler in logging.root.handlers: @@ -36,7 +36,7 @@ target_metadata = Auth.metadata DB_URL = DATABASE_URL if DB_URL: - config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%")) + config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%')) def run_migrations_offline() -> None: @@ -51,12 +51,12 @@ def run_migrations_offline() -> None: script output. """ - url = config.get_main_option("sqlalchemy.url") + url = config.get_main_option('sqlalchemy.url') context.configure( url=url, target_metadata=target_metadata, literal_binds=True, - dialect_opts={"paramstyle": "named"}, + dialect_opts={'paramstyle': 'named'}, ) with context.begin_transaction(): @@ -71,15 +71,13 @@ def run_migrations_online() -> None: """ # Handle SQLCipher URLs - if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): - if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": - raise ValueError( - "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" - ) + if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'): + if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '': + raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs') # Extract database path from SQLCipher URL - db_path = DB_URL.replace("sqlite+sqlcipher://", "") - if db_path.startswith("/"): + db_path = DB_URL.replace('sqlite+sqlcipher://', '') + if db_path.startswith('/'): db_path = db_path[1:] # Remove leading slash for relative paths # Create a custom creator function that uses sqlcipher3 @@ -91,7 +89,7 @@ def run_migrations_online() -> None: return conn connectable = create_engine( - "sqlite://", # Dummy URL since we're using creator + 'sqlite://', # Dummy URL since we're using creator creator=create_sqlcipher_connection, echo=False, ) @@ -99,7 +97,7 @@ def run_migrations_online() -> None: # Standard database connection (existing logic) connectable = engine_from_config( config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", + prefix='sqlalchemy.', poolclass=pool.NullPool, ) diff --git a/backend/open_webui/migrations/util.py b/backend/open_webui/migrations/util.py index 955066602a..6ea2a5f4bb 100644 --- a/backend/open_webui/migrations/util.py +++ b/backend/open_webui/migrations/util.py @@ -12,4 +12,4 @@ def get_existing_tables(): def get_revision_id(): import uuid - return str(uuid.uuid4()).replace("-", "")[:12] + return str(uuid.uuid4()).replace('-', '')[:12] diff --git a/backend/open_webui/migrations/versions/018012973d35_add_indexes.py b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py index 29af427108..c5016e1a8b 100644 --- a/backend/open_webui/migrations/versions/018012973d35_add_indexes.py +++ b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py @@ -9,38 +9,38 @@ Create Date: 2025-08-13 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "018012973d35" -down_revision = "d31026856c01" +revision = '018012973d35' +down_revision = 'd31026856c01' branch_labels = None depends_on = None def upgrade(): # Chat table indexes - op.create_index("folder_id_idx", "chat", ["folder_id"]) - op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"]) - op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"]) - op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"]) - op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"]) + op.create_index('folder_id_idx', 'chat', ['folder_id']) + op.create_index('user_id_pinned_idx', 'chat', ['user_id', 'pinned']) + op.create_index('user_id_archived_idx', 'chat', ['user_id', 'archived']) + op.create_index('updated_at_user_id_idx', 'chat', ['updated_at', 'user_id']) + op.create_index('folder_id_user_id_idx', 'chat', ['folder_id', 'user_id']) # Tag table index - op.create_index("user_id_idx", "tag", ["user_id"]) + op.create_index('user_id_idx', 'tag', ['user_id']) # Function table index - op.create_index("is_global_idx", "function", ["is_global"]) + op.create_index('is_global_idx', 'function', ['is_global']) def downgrade(): # Chat table indexes - op.drop_index("folder_id_idx", table_name="chat") - op.drop_index("user_id_pinned_idx", table_name="chat") - op.drop_index("user_id_archived_idx", table_name="chat") - op.drop_index("updated_at_user_id_idx", table_name="chat") - op.drop_index("folder_id_user_id_idx", table_name="chat") + op.drop_index('folder_id_idx', table_name='chat') + op.drop_index('user_id_pinned_idx', table_name='chat') + op.drop_index('user_id_archived_idx', table_name='chat') + op.drop_index('updated_at_user_id_idx', table_name='chat') + op.drop_index('folder_id_user_id_idx', table_name='chat') # Tag table index - op.drop_index("user_id_idx", table_name="tag") + op.drop_index('user_id_idx', table_name='tag') # Function table index - op.drop_index("is_global_idx", table_name="function") + op.drop_index('is_global_idx', table_name='function') diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py index 8a0ab1b491..caffb7e3b4 100644 --- a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py +++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py @@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector import json -revision = "1af9b942657b" -down_revision = "242a2047eae0" +revision = '1af9b942657b' +down_revision = '242a2047eae0' branch_labels = None depends_on = None @@ -25,43 +25,40 @@ def upgrade(): inspector = Inspector.from_engine(conn) # Clean up potential leftover temp table from previous failures - conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag")) + conn.execute(sa.text('DROP TABLE IF EXISTS _alembic_tmp_tag')) # Check if the 'tag' table exists tables = inspector.get_table_names() # Step 1: Modify Tag table using batch mode for SQLite support - if "tag" in tables: + if 'tag' in tables: # Get the current columns in the 'tag' table - columns = [col["name"] for col in inspector.get_columns("tag")] + columns = [col['name'] for col in inspector.get_columns('tag')] # Get any existing unique constraints on the 'tag' table - current_constraints = inspector.get_unique_constraints("tag") + current_constraints = inspector.get_unique_constraints('tag') - with op.batch_alter_table("tag", schema=None) as batch_op: + with op.batch_alter_table('tag', schema=None) as batch_op: # Check if the unique constraint already exists - if not any( - constraint["name"] == "uq_id_user_id" - for constraint in current_constraints - ): + if not any(constraint['name'] == 'uq_id_user_id' for constraint in current_constraints): # Create unique constraint if it doesn't exist - batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) + batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id']) # Check if the 'data' column exists before trying to drop it - if "data" in columns: - batch_op.drop_column("data") + if 'data' in columns: + batch_op.drop_column('data') # Check if the 'meta' column needs to be created - if "meta" not in columns: + if 'meta' not in columns: # Add the 'meta' column if it doesn't already exist - batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True)) + batch_op.add_column(sa.Column('meta', sa.JSON(), nullable=True)) tag = table( - "tag", - column("id", sa.String()), - column("name", sa.String()), - column("user_id", sa.String()), - column("meta", sa.JSON()), + 'tag', + column('id', sa.String()), + column('name', sa.String()), + column('user_id', sa.String()), + column('meta', sa.JSON()), ) # Step 2: Migrate tags @@ -70,12 +67,12 @@ def upgrade(): tag_updates = {} for row in result: - new_id = row.name.replace(" ", "_").lower() + new_id = row.name.replace(' ', '_').lower() tag_updates[row.id] = new_id for tag_id, new_tag_id in tag_updates.items(): - print(f"Updating tag {tag_id} to {new_tag_id}") - if new_tag_id == "pinned": + print(f'Updating tag {tag_id} to {new_tag_id}') + if new_tag_id == 'pinned': # delete tag delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) conn.execute(delete_stmt) @@ -86,9 +83,7 @@ def upgrade(): if existing_tag_result: # Handle duplicate case: the new_tag_id already exists - print( - f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates." - ) + print(f'Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates.') # Option 1: Delete the current tag if an update to new_tag_id would cause duplication delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) conn.execute(delete_stmt) @@ -98,19 +93,15 @@ def upgrade(): conn.execute(update_stmt) # Add columns `pinned` and `meta` to 'chat' - op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True)) - op.add_column( - "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}") - ) + op.add_column('chat', sa.Column('pinned', sa.Boolean(), nullable=True)) + op.add_column('chat', sa.Column('meta', sa.JSON(), nullable=False, server_default='{}')) - chatidtag = table( - "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String()) - ) + chatidtag = table('chatidtag', column('chat_id', sa.String()), column('tag_name', sa.String())) chat = table( - "chat", - column("id", sa.String()), - column("pinned", sa.Boolean()), - column("meta", sa.JSON()), + 'chat', + column('id', sa.String()), + column('pinned', sa.Boolean()), + column('meta', sa.JSON()), ) # Fetch existing tags @@ -120,29 +111,27 @@ def upgrade(): chat_updates = {} for row in result: chat_id = row.chat_id - tag_name = row.tag_name.replace(" ", "_").lower() + tag_name = row.tag_name.replace(' ', '_').lower() - if tag_name == "pinned": + if tag_name == 'pinned': # Specifically handle 'pinned' tag if chat_id not in chat_updates: - chat_updates[chat_id] = {"pinned": True, "meta": {}} + chat_updates[chat_id] = {'pinned': True, 'meta': {}} else: - chat_updates[chat_id]["pinned"] = True + chat_updates[chat_id]['pinned'] = True else: if chat_id not in chat_updates: - chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}} + chat_updates[chat_id] = {'pinned': False, 'meta': {'tags': [tag_name]}} else: - tags = chat_updates[chat_id]["meta"].get("tags", []) + tags = chat_updates[chat_id]['meta'].get('tags', []) tags.append(tag_name) - chat_updates[chat_id]["meta"]["tags"] = list(set(tags)) + chat_updates[chat_id]['meta']['tags'] = list(set(tags)) # Update chats based on accumulated changes for chat_id, updates in chat_updates.items(): update_stmt = sa.update(chat).where(chat.c.id == chat_id) - update_stmt = update_stmt.values( - meta=updates.get("meta", {}), pinned=updates.get("pinned", False) - ) + update_stmt = update_stmt.values(meta=updates.get('meta', {}), pinned=updates.get('pinned', False)) conn.execute(update_stmt) pass diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py index 6017da3169..7fadb05a92 100644 --- a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py +++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py @@ -12,8 +12,8 @@ from sqlalchemy.sql import table, select, update import json -revision = "242a2047eae0" -down_revision = "6a39f3d8e55c" +revision = '242a2047eae0' +down_revision = '6a39f3d8e55c' branch_labels = None depends_on = None @@ -22,39 +22,37 @@ def upgrade(): conn = op.get_bind() inspector = sa.inspect(conn) - columns = inspector.get_columns("chat") - column_dict = {col["name"]: col for col in columns} + columns = inspector.get_columns('chat') + column_dict = {col['name']: col for col in columns} - chat_column = column_dict.get("chat") - old_chat_exists = "old_chat" in column_dict + chat_column = column_dict.get('chat') + old_chat_exists = 'old_chat' in column_dict if chat_column: - if isinstance(chat_column["type"], sa.Text): + if isinstance(chat_column['type'], sa.Text): print("Converting 'chat' column to JSON") if old_chat_exists: print("Dropping old 'old_chat' column") - op.drop_column("chat", "old_chat") + op.drop_column('chat', 'old_chat') # Step 1: Rename current 'chat' column to 'old_chat' print("Renaming 'chat' column to 'old_chat'") - op.alter_column( - "chat", "chat", new_column_name="old_chat", existing_type=sa.Text() - ) + op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text()) # Step 2: Add new 'chat' column of type JSON print("Adding new 'chat' column of type JSON") - op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) + op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True)) else: # If the column is already JSON, no need to do anything pass # Step 3: Migrate data from 'old_chat' to 'chat' chat_table = table( - "chat", - sa.Column("id", sa.String(), primary_key=True), - sa.Column("old_chat", sa.Text()), - sa.Column("chat", sa.JSON()), + 'chat', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('old_chat', sa.Text()), + sa.Column('chat', sa.JSON()), ) # - Selecting all data from the table @@ -67,41 +65,33 @@ def upgrade(): except json.JSONDecodeError: json_data = None # Handle cases where the text cannot be converted to JSON - connection.execute( - sa.update(chat_table) - .where(chat_table.c.id == row.id) - .values(chat=json_data) - ) + connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data)) # Step 4: Drop 'old_chat' column print("Dropping 'old_chat' column") - op.drop_column("chat", "old_chat") + op.drop_column('chat', 'old_chat') def downgrade(): # Step 1: Add 'old_chat' column back as Text - op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) + op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True)) # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' chat_table = table( - "chat", - sa.Column("id", sa.String(), primary_key=True), - sa.Column("chat", sa.JSON()), - sa.Column("old_chat", sa.Text()), + 'chat', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('chat', sa.JSON()), + sa.Column('old_chat', sa.Text()), ) connection = op.get_bind() results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) for row in results: text_data = json.dumps(row.chat) if row.chat is not None else None - connection.execute( - sa.update(chat_table) - .where(chat_table.c.id == row.id) - .values(old_chat=text_data) - ) + connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data)) # Step 3: Remove the new 'chat' JSON column - op.drop_column("chat", "chat") + op.drop_column('chat', 'chat') # Step 4: Rename 'old_chat' back to 'chat' - op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text()) + op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text()) diff --git a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py index 1a4ae73180..51a8e329f1 100644 --- a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py +++ b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py @@ -13,19 +13,19 @@ import sqlalchemy as sa import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "2f1211949ecc" -down_revision: Union[str, None] = "37f288994c47" +revision: str = '2f1211949ecc' +down_revision: Union[str, None] = '37f288994c47' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # New columns to be added to channel_member table - op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True)) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_active", + 'is_active', sa.Boolean(), nullable=False, default=True, @@ -34,9 +34,9 @@ def upgrade() -> None: ) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_channel_muted", + 'is_channel_muted', sa.Boolean(), nullable=False, default=False, @@ -44,9 +44,9 @@ def upgrade() -> None: ), ) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_channel_pinned", + 'is_channel_pinned', sa.Boolean(), nullable=False, default=False, @@ -54,49 +54,41 @@ def upgrade() -> None: ), ) - op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True)) - op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True)) + op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True)) + op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True)) - op.add_column( - "channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False) - ) - op.add_column( - "channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False)) + op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True)) - op.add_column( - "channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True)) - op.add_column( - "channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True)) # New columns to be added to message table op.add_column( - "message", + 'message', sa.Column( - "is_pinned", + 'is_pinned', sa.Boolean(), nullable=False, default=False, server_default=sa.sql.expression.false(), ), ) - op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True)) - op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True)) + op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True)) + op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True)) def downgrade() -> None: - op.drop_column("channel_member", "updated_at") - op.drop_column("channel_member", "last_read_at") + op.drop_column('channel_member', 'updated_at') + op.drop_column('channel_member', 'last_read_at') - op.drop_column("channel_member", "meta") - op.drop_column("channel_member", "data") + op.drop_column('channel_member', 'meta') + op.drop_column('channel_member', 'data') - op.drop_column("channel_member", "is_channel_pinned") - op.drop_column("channel_member", "is_channel_muted") + op.drop_column('channel_member', 'is_channel_pinned') + op.drop_column('channel_member', 'is_channel_muted') - op.drop_column("message", "pinned_by") - op.drop_column("message", "pinned_at") - op.drop_column("message", "is_pinned") + op.drop_column('message', 'pinned_by') + op.drop_column('message', 'pinned_at') + op.drop_column('message', 'is_pinned') diff --git a/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py index 57bc8748e3..c412107032 100644 --- a/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py +++ b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py @@ -12,8 +12,8 @@ import uuid from alembic import op import sqlalchemy as sa -revision: str = "374d2f66af06" -down_revision: Union[str, None] = "c440947495f3" +revision: str = '374d2f66af06' +down_revision: Union[str, None] = 'c440947495f3' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,13 +26,13 @@ def upgrade() -> None: # We need to assume the OLD structure. old_prompt_table = sa.table( - "prompt", - sa.column("command", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("title", sa.Text()), - sa.column("content", sa.Text()), - sa.column("timestamp", sa.BigInteger()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('title', sa.Text()), + sa.column('content', sa.Text()), + sa.column('timestamp', sa.BigInteger()), + sa.column('access_control', sa.JSON()), ) # Check if table exists/read data @@ -53,61 +53,61 @@ def upgrade() -> None: # Step 2: Create new prompt table with 'id' as PRIMARY KEY op.create_table( - "prompt_new", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("command", sa.String(), unique=True, index=True), - sa.Column("user_id", sa.String(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("content", sa.Text(), nullable=False), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"), - sa.Column("version_id", sa.Text(), nullable=True), - sa.Column("tags", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + 'prompt_new', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('command', sa.String(), unique=True, index=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'), + sa.Column('version_id', sa.Text(), nullable=True), + sa.Column('tags', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) # Step 3: Create prompt_history table op.create_table( - "prompt_history", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("prompt_id", sa.Text(), nullable=False, index=True), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("snapshot", sa.JSON(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("commit_message", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'prompt_history', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('prompt_id', sa.Text(), nullable=False, index=True), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('snapshot', sa.JSON(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('commit_message', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), ) # Step 4: Migrate data prompt_new_table = sa.table( - "prompt_new", - sa.column("id", sa.Text()), - sa.column("command", sa.String()), - sa.column("user_id", sa.String()), - sa.column("name", sa.Text()), - sa.column("content", sa.Text()), - sa.column("data", sa.JSON()), - sa.column("meta", sa.JSON()), - sa.column("access_control", sa.JSON()), - sa.column("is_active", sa.Boolean()), - sa.column("version_id", sa.Text()), - sa.column("tags", sa.JSON()), - sa.column("created_at", sa.BigInteger()), - sa.column("updated_at", sa.BigInteger()), + 'prompt_new', + sa.column('id', sa.Text()), + sa.column('command', sa.String()), + sa.column('user_id', sa.String()), + sa.column('name', sa.Text()), + sa.column('content', sa.Text()), + sa.column('data', sa.JSON()), + sa.column('meta', sa.JSON()), + sa.column('access_control', sa.JSON()), + sa.column('is_active', sa.Boolean()), + sa.column('version_id', sa.Text()), + sa.column('tags', sa.JSON()), + sa.column('created_at', sa.BigInteger()), + sa.column('updated_at', sa.BigInteger()), ) prompt_history_table = sa.table( - "prompt_history", - sa.column("id", sa.Text()), - sa.column("prompt_id", sa.Text()), - sa.column("parent_id", sa.Text()), - sa.column("snapshot", sa.JSON()), - sa.column("user_id", sa.Text()), - sa.column("commit_message", sa.Text()), - sa.column("created_at", sa.BigInteger()), + 'prompt_history', + sa.column('id', sa.Text()), + sa.column('prompt_id', sa.Text()), + sa.column('parent_id', sa.Text()), + sa.column('snapshot', sa.JSON()), + sa.column('user_id', sa.Text()), + sa.column('commit_message', sa.Text()), + sa.column('created_at', sa.BigInteger()), ) for row in existing_prompts: @@ -120,7 +120,7 @@ def upgrade() -> None: new_uuid = str(uuid.uuid4()) history_uuid = str(uuid.uuid4()) - clean_command = command[1:] if command and command.startswith("/") else command + clean_command = command[1:] if command and command.startswith('/') else command # Insert into prompt_new conn.execute( @@ -148,12 +148,12 @@ def upgrade() -> None: prompt_id=new_uuid, parent_id=None, snapshot={ - "name": title, - "content": content, - "command": clean_command, - "data": {}, - "meta": {}, - "access_control": access_control, + 'name': title, + 'content': content, + 'command': clean_command, + 'data': {}, + 'meta': {}, + 'access_control': access_control, }, user_id=user_id, commit_message=None, @@ -162,8 +162,8 @@ def upgrade() -> None: ) # Step 5: Replace old table with new one - op.drop_table("prompt") - op.rename_table("prompt_new", "prompt") + op.drop_table('prompt') + op.rename_table('prompt_new', 'prompt') def downgrade() -> None: @@ -171,13 +171,13 @@ def downgrade() -> None: # Step 1: Read new data prompt_table = sa.table( - "prompt", - sa.column("command", sa.String()), - sa.column("name", sa.Text()), - sa.column("created_at", sa.BigInteger()), - sa.column("user_id", sa.Text()), - sa.column("content", sa.Text()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.String()), + sa.column('name', sa.Text()), + sa.column('created_at', sa.BigInteger()), + sa.column('user_id', sa.Text()), + sa.column('content', sa.Text()), + sa.column('access_control', sa.JSON()), ) try: @@ -195,31 +195,31 @@ def downgrade() -> None: current_data = [] # Step 2: Drop history and table - op.drop_table("prompt_history") - op.drop_table("prompt") + op.drop_table('prompt_history') + op.drop_table('prompt') # Step 3: Recreate old table (command as PK?) # Assuming old schema: op.create_table( - "prompt", - sa.Column("command", sa.String(), primary_key=True), - sa.Column("user_id", sa.String()), - sa.Column("title", sa.Text()), - sa.Column("content", sa.Text()), - sa.Column("timestamp", sa.BigInteger()), - sa.Column("access_control", sa.JSON()), - sa.Column("id", sa.Integer(), nullable=True), + 'prompt', + sa.Column('command', sa.String(), primary_key=True), + sa.Column('user_id', sa.String()), + sa.Column('title', sa.Text()), + sa.Column('content', sa.Text()), + sa.Column('timestamp', sa.BigInteger()), + sa.Column('access_control', sa.JSON()), + sa.Column('id', sa.Integer(), nullable=True), ) # Step 4: Restore data old_prompt_table = sa.table( - "prompt", - sa.column("command", sa.String()), - sa.column("user_id", sa.String()), - sa.column("title", sa.Text()), - sa.column("content", sa.Text()), - sa.column("timestamp", sa.BigInteger()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.String()), + sa.column('user_id', sa.String()), + sa.column('title', sa.Text()), + sa.column('content', sa.Text()), + sa.column('timestamp', sa.BigInteger()), + sa.column('access_control', sa.JSON()), ) for row in current_data: @@ -231,9 +231,7 @@ def downgrade() -> None: access_control = row[5] # Restore leading / - old_command = ( - "/" + command if command and not command.startswith("/") else command - ) + old_command = '/' + command if command and not command.startswith('/') else command conn.execute( sa.insert(old_prompt_table).values( diff --git a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py index 16fb0e85eb..170137f23c 100644 --- a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py +++ b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py @@ -9,8 +9,8 @@ Create Date: 2024-12-30 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "3781e22d8b01" -down_revision = "7826ab40b532" +revision = '3781e22d8b01' +down_revision = '7826ab40b532' branch_labels = None depends_on = None @@ -18,9 +18,9 @@ depends_on = None def upgrade(): # Add 'type' column to the 'channel' table op.add_column( - "channel", + 'channel', sa.Column( - "type", + 'type', sa.Text(), nullable=True, ), @@ -28,43 +28,31 @@ def upgrade(): # Add 'parent_id' column to the 'message' table for threads op.add_column( - "message", - sa.Column("parent_id", sa.Text(), nullable=True), + 'message', + sa.Column('parent_id', sa.Text(), nullable=True), ) op.create_table( - "message_reaction", - sa.Column( - "id", sa.Text(), nullable=False, primary_key=True, unique=True - ), # Unique reaction ID - sa.Column("user_id", sa.Text(), nullable=False), # User who reacted - sa.Column( - "message_id", sa.Text(), nullable=False - ), # Message that was reacted to - sa.Column( - "name", sa.Text(), nullable=False - ), # Reaction name (e.g. "thumbs_up") - sa.Column( - "created_at", sa.BigInteger(), nullable=True - ), # Timestamp of when the reaction was added + 'message_reaction', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID + sa.Column('user_id', sa.Text(), nullable=False), # User who reacted + sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to + sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up") + sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added ) op.create_table( - "channel_member", - sa.Column( - "id", sa.Text(), nullable=False, primary_key=True, unique=True - ), # Record ID for the membership row - sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel - sa.Column("user_id", sa.Text(), nullable=False), # Associated user - sa.Column( - "created_at", sa.BigInteger(), nullable=True - ), # Timestamp of when the user joined the channel + 'channel_member', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row + sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel + sa.Column('user_id', sa.Text(), nullable=False), # Associated user + sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel ) def downgrade(): # Revert 'type' column addition to the 'channel' table - op.drop_column("channel", "type") - op.drop_column("message", "parent_id") - op.drop_table("message_reaction") - op.drop_table("channel_member") + op.drop_column('channel', 'type') + op.drop_column('message', 'parent_id') + op.drop_table('message_reaction') + op.drop_table('channel_member') diff --git a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py index 229bb8cffb..4bf24d3b46 100644 --- a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py +++ b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py @@ -15,8 +15,8 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "37f288994c47" -down_revision: Union[str, None] = "a5c220713937" +revision: str = '37f288994c47' +down_revision: Union[str, None] = 'a5c220713937' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,50 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # 1. Create new table op.create_table( - "group_member", - sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), + 'group_member', + sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False), sa.Column( - "group_id", + 'group_id', sa.Text(), - sa.ForeignKey("group.id", ondelete="CASCADE"), + sa.ForeignKey('group.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "user_id", + 'user_id', sa.Text(), - sa.ForeignKey("user.id", ondelete="CASCADE"), + sa.ForeignKey('user.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'), ) connection = op.get_bind() # 2. Read existing group with user_ids JSON column group_table = sa.Table( - "group", + 'group', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG + sa.Column('id', sa.Text()), + sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG ) - results = connection.execute( - sa.select(group_table.c.id, group_table.c.user_ids) - ).fetchall() + results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall() print(results) # 3. Insert members into group_member table gm_table = sa.Table( - "group_member", + 'group_member', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("group_id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('id', sa.Text()), + sa.Column('group_id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) now = int(time.time()) @@ -86,11 +84,11 @@ def upgrade() -> None: rows = [ { - "id": str(uuid.uuid4()), - "group_id": group_id, - "user_id": uid, - "created_at": now, - "updated_at": now, + 'id': str(uuid.uuid4()), + 'group_id': group_id, + 'user_id': uid, + 'created_at': now, + 'updated_at': now, } for uid in user_ids ] @@ -99,47 +97,41 @@ def upgrade() -> None: connection.execute(gm_table.insert(), rows) # 4. Optionally drop the old column - with op.batch_alter_table("group") as batch: - batch.drop_column("user_ids") + with op.batch_alter_table('group') as batch: + batch.drop_column('user_ids') def downgrade(): # Reverse: restore user_ids column - with op.batch_alter_table("group") as batch: - batch.add_column(sa.Column("user_ids", sa.JSON())) + with op.batch_alter_table('group') as batch: + batch.add_column(sa.Column('user_ids', sa.JSON())) connection = op.get_bind() gm_table = sa.Table( - "group_member", + 'group_member', sa.MetaData(), - sa.Column("group_id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('group_id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) group_table = sa.Table( - "group", + 'group', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_ids", sa.JSON()), + sa.Column('id', sa.Text()), + sa.Column('user_ids', sa.JSON()), ) # Build JSON arrays again results = connection.execute(sa.select(group_table.c.id)).fetchall() for (group_id,) in results: - members = connection.execute( - sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id) - ).fetchall() + members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall() member_ids = [m[0] for m in members] - connection.execute( - group_table.update() - .where(group_table.c.id == group_id) - .values(user_ids=member_ids) - ) + connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids)) # Drop the new table - op.drop_table("group_member") + op.drop_table('group_member') diff --git a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py index af8340a3cb..d415f500f3 100644 --- a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py +++ b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py @@ -12,8 +12,8 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "38d63c18f30f" -down_revision: Union[str, None] = "3af16a1c9fb6" +revision: str = '38d63c18f30f' +down_revision: Union[str, None] = '3af16a1c9fb6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,59 +21,55 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint) inspector = sa.inspect(op.get_bind()) - columns = inspector.get_columns("user") + columns = inspector.get_columns('user') - pk_columns = inspector.get_pk_constraint("user")["constrained_columns"] - id_column = next((col for col in columns if col["name"] == "id"), None) + pk_columns = inspector.get_pk_constraint('user')['constrained_columns'] + id_column = next((col for col in columns if col['name'] == 'id'), None) - if id_column and not id_column.get("unique", False): - unique_constraints = inspector.get_unique_constraints("user") - unique_columns = {tuple(u["column_names"]) for u in unique_constraints} + if id_column and not id_column.get('unique', False): + unique_constraints = inspector.get_unique_constraints('user') + unique_columns = {tuple(u['column_names']) for u in unique_constraints} - with op.batch_alter_table("user") as batch_op: + with op.batch_alter_table('user') as batch_op: # If primary key is wrong, drop it - if pk_columns and pk_columns != ["id"]: - batch_op.drop_constraint( - inspector.get_pk_constraint("user")["name"], type_="primary" - ) + if pk_columns and pk_columns != ['id']: + batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary') # Add unique constraint if missing - if ("id",) not in unique_columns: - batch_op.create_unique_constraint("uq_user_id", ["id"]) + if ('id',) not in unique_columns: + batch_op.create_unique_constraint('uq_user_id', ['id']) # Re-create correct primary key - batch_op.create_primary_key("pk_user_id", ["id"]) + batch_op.create_primary_key('pk_user_id', ['id']) # Create oauth_session table op.create_table( - "oauth_session", - sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True), + 'oauth_session', + sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True), sa.Column( - "user_id", + 'user_id', sa.Text(), - sa.ForeignKey("user.id", ondelete="CASCADE"), + sa.ForeignKey('user.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("provider", sa.Text(), nullable=False), - sa.Column("token", sa.Text(), nullable=False), - sa.Column("expires_at", sa.BigInteger(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('provider', sa.Text(), nullable=False), + sa.Column('token', sa.Text(), nullable=False), + sa.Column('expires_at', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) # Create indexes for better performance - op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"]) - op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"]) - op.create_index( - "idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"] - ) + op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id']) + op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at']) + op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider']) def downgrade() -> None: # Drop indexes first - op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session") - op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session") - op.drop_index("idx_oauth_session_user_id", table_name="oauth_session") + op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session') + op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session') + op.drop_index('idx_oauth_session_user_id', table_name='oauth_session') # Drop the table - op.drop_table("oauth_session") + op.drop_table('oauth_session') diff --git a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py index 6e010424b0..31bd355ede 100644 --- a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py +++ b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py @@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector import json -revision = "3ab32c4b8f59" -down_revision = "1af9b942657b" +revision = '3ab32c4b8f59' +down_revision = '1af9b942657b' branch_labels = None depends_on = None @@ -24,58 +24,55 @@ def upgrade(): inspector = Inspector.from_engine(conn) # Inspecting the 'tag' table constraints and structure - existing_pk = inspector.get_pk_constraint("tag") - unique_constraints = inspector.get_unique_constraints("tag") - existing_indexes = inspector.get_indexes("tag") + existing_pk = inspector.get_pk_constraint('tag') + unique_constraints = inspector.get_unique_constraints('tag') + existing_indexes = inspector.get_indexes('tag') - print(f"Primary Key: {existing_pk}") - print(f"Unique Constraints: {unique_constraints}") - print(f"Indexes: {existing_indexes}") + print(f'Primary Key: {existing_pk}') + print(f'Unique Constraints: {unique_constraints}') + print(f'Indexes: {existing_indexes}') - with op.batch_alter_table("tag", schema=None) as batch_op: + with op.batch_alter_table('tag', schema=None) as batch_op: # Drop existing primary key constraint if it exists - if existing_pk and existing_pk.get("constrained_columns"): - pk_name = existing_pk.get("name") + if existing_pk and existing_pk.get('constrained_columns'): + pk_name = existing_pk.get('name') if pk_name: - print(f"Dropping primary key constraint: {pk_name}") - batch_op.drop_constraint(pk_name, type_="primary") + print(f'Dropping primary key constraint: {pk_name}') + batch_op.drop_constraint(pk_name, type_='primary') # Now create the new primary key with the combination of 'id' and 'user_id' print("Creating new primary key with 'id' and 'user_id'.") - batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"]) + batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id']) # Drop unique constraints that could conflict with the new primary key for constraint in unique_constraints: if ( - constraint["name"] == "uq_id_user_id" + constraint['name'] == 'uq_id_user_id' ): # Adjust this name according to what is actually returned by the inspector - print(f"Dropping unique constraint: {constraint['name']}") - batch_op.drop_constraint(constraint["name"], type_="unique") + print(f'Dropping unique constraint: {constraint["name"]}') + batch_op.drop_constraint(constraint['name'], type_='unique') for index in existing_indexes: - if index["unique"]: - if not any( - constraint["name"] == index["name"] - for constraint in unique_constraints - ): + if index['unique']: + if not any(constraint['name'] == index['name'] for constraint in unique_constraints): # You are attempting to drop unique indexes - print(f"Dropping unique index: {index['name']}") - batch_op.drop_index(index["name"]) + print(f'Dropping unique index: {index["name"]}') + batch_op.drop_index(index['name']) def downgrade(): conn = op.get_bind() inspector = Inspector.from_engine(conn) - current_pk = inspector.get_pk_constraint("tag") + current_pk = inspector.get_pk_constraint('tag') - with op.batch_alter_table("tag", schema=None) as batch_op: + with op.batch_alter_table('tag', schema=None) as batch_op: # Drop the current primary key first, if it matches the one we know we added in upgrade - if current_pk and "pk_id_user_id" == current_pk.get("name"): - batch_op.drop_constraint("pk_id_user_id", type_="primary") + if current_pk and 'pk_id_user_id' == current_pk.get('name'): + batch_op.drop_constraint('pk_id_user_id', type_='primary') # Restore the original primary key - batch_op.create_primary_key("pk_id", ["id"]) + batch_op.create_primary_key('pk_id', ['id']) # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary - batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) + batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id']) diff --git a/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py index ab980f27ce..629c1c8c24 100644 --- a/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py +++ b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py @@ -12,21 +12,21 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "3af16a1c9fb6" -down_revision: Union[str, None] = "018012973d35" +revision: str = '3af16a1c9fb6' +down_revision: Union[str, None] = '018012973d35' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True)) - op.add_column("user", sa.Column("bio", sa.Text(), nullable=True)) - op.add_column("user", sa.Column("gender", sa.Text(), nullable=True)) - op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True)) + op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True)) + op.add_column('user', sa.Column('bio', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('gender', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True)) def downgrade() -> None: - op.drop_column("user", "username") - op.drop_column("user", "bio") - op.drop_column("user", "gender") - op.drop_column("user", "date_of_birth") + op.drop_column('user', 'username') + op.drop_column('user', 'bio') + op.drop_column('user', 'gender') + op.drop_column('user', 'date_of_birth') diff --git a/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py b/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py index 82249bb278..f772987a44 100644 --- a/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py +++ b/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py @@ -18,38 +18,38 @@ import json import uuid # revision identifiers, used by Alembic. -revision: str = "3e0e00844bb0" -down_revision: Union[str, None] = "90ef40d4714e" +revision: str = '3e0e00844bb0' +down_revision: Union[str, None] = '90ef40d4714e' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "knowledge_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'knowledge_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "knowledge_id", + 'knowledge_id', sa.Text(), - sa.ForeignKey("knowledge.id", ondelete="CASCADE"), + sa.ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"), - sa.Index("ix_knowledge_file_file_id", "file_id"), - sa.Index("ix_knowledge_file_user_id", "user_id"), + sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'), + sa.Index('ix_knowledge_file_file_id', 'file_id'), + sa.Index('ix_knowledge_file_user_id', 'user_id'), # unique constraints sa.UniqueConstraint( - "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file" + 'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file' ), # prevent duplicate entries ) @@ -57,35 +57,33 @@ def upgrade() -> None: # 2. Read existing group with user_ids JSON column knowledge_table = sa.Table( - "knowledge", + 'knowledge', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG + sa.Column('id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG ) results = connection.execute( - sa.select( - knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data - ) + sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data) ).fetchall() # 3. Insert members into group_member table kf_table = sa.Table( - "knowledge_file", + 'knowledge_file', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("knowledge_id", sa.Text()), - sa.Column("file_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('knowledge_id', sa.Text()), + sa.Column('file_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) file_table = sa.Table( - "file", + 'file', sa.MetaData(), - sa.Column("id", sa.Text()), + sa.Column('id', sa.Text()), ) now = int(time.time()) @@ -102,50 +100,48 @@ def upgrade() -> None: if not isinstance(data, dict): continue - file_ids = data.get("file_ids", []) + file_ids = data.get('file_ids', []) for file_id in file_ids: - file_exists = connection.execute( - sa.select(file_table.c.id).where(file_table.c.id == file_id) - ).fetchone() + file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone() if not file_exists: continue # skip non-existing files row = { - "id": str(uuid.uuid4()), - "user_id": user_id, - "knowledge_id": knowledge_id, - "file_id": file_id, - "created_at": now, - "updated_at": now, + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'knowledge_id': knowledge_id, + 'file_id': file_id, + 'created_at': now, + 'updated_at': now, } connection.execute(kf_table.insert().values(**row)) - with op.batch_alter_table("knowledge") as batch: - batch.drop_column("data") + with op.batch_alter_table('knowledge') as batch: + batch.drop_column('data') def downgrade() -> None: # 1. Add back the old data column - op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True)) + op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True)) connection = op.get_bind() # 2. Read knowledge_file entries and reconstruct data JSON knowledge_table = sa.Table( - "knowledge", + 'knowledge', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("data", sa.JSON()), + sa.Column('id', sa.Text()), + sa.Column('data', sa.JSON()), ) kf_table = sa.Table( - "knowledge_file", + 'knowledge_file', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("knowledge_id", sa.Text()), - sa.Column("file_id", sa.Text()), + sa.Column('id', sa.Text()), + sa.Column('knowledge_id', sa.Text()), + sa.Column('file_id', sa.Text()), ) results = connection.execute(sa.select(knowledge_table.c.id)).fetchall() @@ -157,13 +153,9 @@ def downgrade() -> None: file_ids_list = [fid for (fid,) in file_ids] - data_json = {"file_ids": file_ids_list} + data_json = {'file_ids': file_ids_list} - connection.execute( - knowledge_table.update() - .where(knowledge_table.c.id == knowledge_id) - .values(data=data_json) - ) + connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json)) # 3. Drop the knowledge_file table - op.drop_table("knowledge_file") + op.drop_table('knowledge_file') diff --git a/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py index 16f7967c8e..91e0dce0be 100644 --- a/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py +++ b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py @@ -9,56 +9,56 @@ Create Date: 2024-10-23 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "4ace53fd72c8" -down_revision = "af906e964978" +revision = '4ace53fd72c8' +down_revision = 'af906e964978' branch_labels = None depends_on = None def upgrade(): # Perform safe alterations using batch operation - with op.batch_alter_table("folder", schema=None) as batch_op: + with op.batch_alter_table('folder', schema=None) as batch_op: # Step 1: Remove server defaults for created_at and updated_at batch_op.alter_column( - "created_at", + 'created_at', server_default=None, # Removing server default ) batch_op.alter_column( - "updated_at", + 'updated_at', server_default=None, # Removing server default ) # Step 2: Change the column types to BigInteger for created_at batch_op.alter_column( - "created_at", + 'created_at', type_=sa.BigInteger(), existing_type=sa.DateTime(), existing_nullable=False, - postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL + postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL ) # Change the column types to BigInteger for updated_at batch_op.alter_column( - "updated_at", + 'updated_at', type_=sa.BigInteger(), existing_type=sa.DateTime(), existing_nullable=False, - postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL + postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL ) def downgrade(): # Downgrade: Convert columns back to DateTime and restore defaults - with op.batch_alter_table("folder", schema=None) as batch_op: + with op.batch_alter_table('folder', schema=None) as batch_op: batch_op.alter_column( - "created_at", + 'created_at', type_=sa.DateTime(), existing_type=sa.BigInteger(), existing_nullable=False, server_default=sa.func.now(), # Restoring server default on downgrade ) batch_op.alter_column( - "updated_at", + 'updated_at', type_=sa.DateTime(), existing_type=sa.BigInteger(), existing_nullable=False, diff --git a/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py index 54176dc46e..79f0e8827e 100644 --- a/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py +++ b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py @@ -9,40 +9,40 @@ Create Date: 2024-12-22 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "57c599a3cb57" -down_revision = "922e7a387820" +revision = '57c599a3cb57' +down_revision = '922e7a387820' branch_labels = None depends_on = None def upgrade(): op.create_table( - "channel", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text()), - sa.Column("name", sa.Text()), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'channel', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text()), + sa.Column('name', sa.Text()), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) op.create_table( - "message", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text()), - sa.Column("channel_id", sa.Text(), nullable=True), - sa.Column("content", sa.Text()), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'message', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text()), + sa.Column('channel_id', sa.Text(), nullable=True), + sa.Column('content', sa.Text()), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) def downgrade(): - op.drop_table("channel") + op.drop_table('channel') - op.drop_table("message") + op.drop_table('message') diff --git a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py index f3ef62fd64..2bd2d9fd60 100644 --- a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py +++ b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py @@ -13,41 +13,39 @@ import sqlalchemy as sa import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "6283dc0e4d8d" -down_revision: Union[str, None] = "3e0e00844bb0" +revision: str = '6283dc0e4d8d' +down_revision: Union[str, None] = '3e0e00844bb0' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "channel_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'channel_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "channel_id", + 'channel_id', sa.Text(), - sa.ForeignKey("channel.id", ondelete="CASCADE"), + sa.ForeignKey('channel.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_channel_file_channel_id", "channel_id"), - sa.Index("ix_channel_file_file_id", "file_id"), - sa.Index("ix_channel_file_user_id", "user_id"), + sa.Index('ix_channel_file_channel_id', 'channel_id'), + sa.Index('ix_channel_file_file_id', 'file_id'), + sa.Index('ix_channel_file_user_id', 'user_id'), # unique constraints - sa.UniqueConstraint( - "channel_id", "file_id", name="uq_channel_file_channel_file" - ), # prevent duplicate entries + sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries ) def downgrade() -> None: - op.drop_table("channel_file") + op.drop_table('channel_file') diff --git a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py index d6083d7177..c65ca01415 100644 --- a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py +++ b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py @@ -11,37 +11,37 @@ import sqlalchemy as sa from sqlalchemy.sql import table, column, select import json -revision = "6a39f3d8e55c" -down_revision = "c0fbf31ca0db" +revision = '6a39f3d8e55c' +down_revision = 'c0fbf31ca0db' branch_labels = None depends_on = None def upgrade(): # Creating the 'knowledge' table - print("Creating knowledge table") + print('Creating knowledge table') knowledge_table = op.create_table( - "knowledge", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'knowledge', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) - print("Migrating data from document table to knowledge table") + print('Migrating data from document table to knowledge table') # Representation of the existing 'document' table document_table = table( - "document", - column("collection_name", sa.String()), - column("user_id", sa.String()), - column("name", sa.String()), - column("title", sa.Text()), - column("content", sa.Text()), - column("timestamp", sa.BigInteger()), + 'document', + column('collection_name', sa.String()), + column('user_id', sa.String()), + column('name', sa.String()), + column('title', sa.Text()), + column('content', sa.Text()), + column('timestamp', sa.BigInteger()), ) # Select all from existing document table @@ -64,9 +64,9 @@ def upgrade(): user_id=doc.user_id, description=doc.name, meta={ - "legacy": True, - "document": True, - "tags": json.loads(doc.content or "{}").get("tags", []), + 'legacy': True, + 'document': True, + 'tags': json.loads(doc.content or '{}').get('tags', []), }, name=doc.title, created_at=doc.timestamp, @@ -76,4 +76,4 @@ def upgrade(): def downgrade(): - op.drop_table("knowledge") + op.drop_table('knowledge') diff --git a/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py b/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py index c8afe9d51a..4211c6642e 100644 --- a/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py +++ b/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py @@ -9,18 +9,18 @@ Create Date: 2024-12-23 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "7826ab40b532" -down_revision = "57c599a3cb57" +revision = '7826ab40b532' +down_revision = '57c599a3cb57' branch_labels = None depends_on = None def upgrade(): op.add_column( - "file", - sa.Column("access_control", sa.JSON(), nullable=True), + 'file', + sa.Column('access_control', sa.JSON(), nullable=True), ) def downgrade(): - op.drop_column("file", "access_control") + op.drop_column('file', 'access_control') diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index 9e56282ef0..39f488d72e 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -16,7 +16,7 @@ from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables # revision identifiers, used by Alembic. -revision: str = "7e5b5dc7342b" +revision: str = '7e5b5dc7342b' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,179 +26,179 @@ def upgrade() -> None: existing_tables = set(get_existing_tables()) # ### commands auto generated by Alembic - please adjust! ### - if "auth" not in existing_tables: + if 'auth' not in existing_tables: op.create_table( - "auth", - sa.Column("id", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=True), - sa.Column("password", sa.Text(), nullable=True), - sa.Column("active", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'auth', + sa.Column('id', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=True), + sa.Column('password', sa.Text(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "chat" not in existing_tables: + if 'chat' not in existing_tables: op.create_table( - "chat", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("chat", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("share_id", sa.Text(), nullable=True), - sa.Column("archived", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("share_id"), + 'chat', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('chat', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('share_id', sa.Text(), nullable=True), + sa.Column('archived', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('share_id'), ) - if "chatidtag" not in existing_tables: + if 'chatidtag' not in existing_tables: op.create_table( - "chatidtag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("tag_name", sa.String(), nullable=True), - sa.Column("chat_id", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'chatidtag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('tag_name', sa.String(), nullable=True), + sa.Column('chat_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "document" not in existing_tables: + if 'document' not in existing_tables: op.create_table( - "document", - sa.Column("collection_name", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("collection_name"), - sa.UniqueConstraint("name"), + 'document', + sa.Column('collection_name', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('collection_name'), + sa.UniqueConstraint('name'), ) - if "file" not in existing_tables: + if 'file' not in existing_tables: op.create_table( - "file", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'file', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "function" not in existing_tables: + if 'function' not in existing_tables: op.create_table( - "function", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("type", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=True), - sa.Column("is_global", sa.Boolean(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'function', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('is_global', sa.Boolean(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "memory" not in existing_tables: + if 'memory' not in existing_tables: op.create_table( - "memory", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'memory', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "model" not in existing_tables: + if 'model' not in existing_tables: op.create_table( - "model", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("base_model_id", sa.Text(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("params", JSONField(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'model', + sa.Column('id', sa.Text(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('base_model_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('params', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "prompt" not in existing_tables: + if 'prompt' not in existing_tables: op.create_table( - "prompt", - sa.Column("command", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("command"), + 'prompt', + sa.Column('command', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('command'), ) - if "tag" not in existing_tables: + if 'tag' not in existing_tables: op.create_table( - "tag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("data", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'tag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('data', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "tool" not in existing_tables: + if 'tool' not in existing_tables: op.create_table( - "tool", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("specs", JSONField(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'tool', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('specs', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "user" not in existing_tables: + if 'user' not in existing_tables: op.create_table( - "user", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("email", sa.String(), nullable=True), - sa.Column("role", sa.String(), nullable=True), - sa.Column("profile_image_url", sa.Text(), nullable=True), - sa.Column("last_active_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column("settings", JSONField(), nullable=True), - sa.Column("info", JSONField(), nullable=True), - sa.Column("oauth_sub", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("api_key"), - sa.UniqueConstraint("oauth_sub"), + 'user', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=True), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('last_active_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('api_key', sa.String(), nullable=True), + sa.Column('settings', JSONField(), nullable=True), + sa.Column('info', JSONField(), nullable=True), + sa.Column('oauth_sub', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key'), + sa.UniqueConstraint('oauth_sub'), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("user") - op.drop_table("tool") - op.drop_table("tag") - op.drop_table("prompt") - op.drop_table("model") - op.drop_table("memory") - op.drop_table("function") - op.drop_table("file") - op.drop_table("document") - op.drop_table("chatidtag") - op.drop_table("chat") - op.drop_table("auth") + op.drop_table('user') + op.drop_table('tool') + op.drop_table('tag') + op.drop_table('prompt') + op.drop_table('model') + op.drop_table('memory') + op.drop_table('function') + op.drop_table('file') + op.drop_table('document') + op.drop_table('chatidtag') + op.drop_table('chat') + op.drop_table('auth') # ### end Alembic commands ### diff --git a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py index 3853ec50d9..e45a2443df 100644 --- a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py +++ b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py @@ -13,36 +13,34 @@ import sqlalchemy as sa import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "81cc2ce44d79" -down_revision: Union[str, None] = "6283dc0e4d8d" +revision: str = '81cc2ce44d79' +down_revision: Union[str, None] = '6283dc0e4d8d' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add message_id column to channel_file table - with op.batch_alter_table("channel_file", schema=None) as batch_op: + with op.batch_alter_table('channel_file', schema=None) as batch_op: batch_op.add_column( sa.Column( - "message_id", + 'message_id', sa.Text(), - sa.ForeignKey( - "message.id", ondelete="CASCADE", name="fk_channel_file_message_id" - ), + sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'), nullable=True, ) ) # Add data column to knowledge table - with op.batch_alter_table("knowledge", schema=None) as batch_op: - batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True)) + with op.batch_alter_table('knowledge', schema=None) as batch_op: + batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True)) def downgrade() -> None: # Remove message_id column from channel_file table - with op.batch_alter_table("channel_file", schema=None) as batch_op: - batch_op.drop_column("message_id") + with op.batch_alter_table('channel_file', schema=None) as batch_op: + batch_op.drop_column('message_id') # Remove data column from knowledge table - with op.batch_alter_table("knowledge", schema=None) as batch_op: - batch_op.drop_column("data") + with op.batch_alter_table('knowledge', schema=None) as batch_op: + batch_op.drop_column('data') diff --git a/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py index fda33d17bf..3254b57858 100644 --- a/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py +++ b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py @@ -16,8 +16,8 @@ import sqlalchemy as sa log = logging.getLogger(__name__) -revision: str = "8452d01d26d7" -down_revision: Union[str, None] = "374d2f66af06" +revision: str = '8452d01d26d7' +down_revision: Union[str, None] = '374d2f66af06' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -51,74 +51,68 @@ def _flush_batch(conn, table, batch): except Exception as e: sp.rollback() failed += 1 - log.warning(f"Failed to insert message {msg['id']}: {e}") + log.warning(f'Failed to insert message {msg["id"]}: {e}') return inserted, failed def upgrade() -> None: # Step 1: Create table op.create_table( - "chat_message", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("chat_id", sa.Text(), nullable=False, index=True), - sa.Column("user_id", sa.Text(), index=True), - sa.Column("role", sa.Text(), nullable=False), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("content", sa.JSON(), nullable=True), - sa.Column("output", sa.JSON(), nullable=True), - sa.Column("model_id", sa.Text(), nullable=True, index=True), - sa.Column("files", sa.JSON(), nullable=True), - sa.Column("sources", sa.JSON(), nullable=True), - sa.Column("embeds", sa.JSON(), nullable=True), - sa.Column("done", sa.Boolean(), default=True), - sa.Column("status_history", sa.JSON(), nullable=True), - sa.Column("error", sa.JSON(), nullable=True), - sa.Column("usage", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), index=True), - sa.Column("updated_at", sa.BigInteger()), - sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"), + 'chat_message', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('chat_id', sa.Text(), nullable=False, index=True), + sa.Column('user_id', sa.Text(), index=True), + sa.Column('role', sa.Text(), nullable=False), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('content', sa.JSON(), nullable=True), + sa.Column('output', sa.JSON(), nullable=True), + sa.Column('model_id', sa.Text(), nullable=True, index=True), + sa.Column('files', sa.JSON(), nullable=True), + sa.Column('sources', sa.JSON(), nullable=True), + sa.Column('embeds', sa.JSON(), nullable=True), + sa.Column('done', sa.Boolean(), default=True), + sa.Column('status_history', sa.JSON(), nullable=True), + sa.Column('error', sa.JSON(), nullable=True), + sa.Column('usage', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), index=True), + sa.Column('updated_at', sa.BigInteger()), + sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'), ) # Create composite indexes - op.create_index( - "chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"] - ) - op.create_index( - "chat_message_model_created_idx", "chat_message", ["model_id", "created_at"] - ) - op.create_index( - "chat_message_user_created_idx", "chat_message", ["user_id", "created_at"] - ) + op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id']) + op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at']) + op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at']) # Step 2: Backfill from existing chats conn = op.get_bind() chat_table = sa.table( - "chat", - sa.column("id", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("chat", sa.JSON()), + 'chat', + sa.column('id', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('chat', sa.JSON()), ) chat_message_table = sa.table( - "chat_message", - sa.column("id", sa.Text()), - sa.column("chat_id", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("role", sa.Text()), - sa.column("parent_id", sa.Text()), - sa.column("content", sa.JSON()), - sa.column("output", sa.JSON()), - sa.column("model_id", sa.Text()), - sa.column("files", sa.JSON()), - sa.column("sources", sa.JSON()), - sa.column("embeds", sa.JSON()), - sa.column("done", sa.Boolean()), - sa.column("status_history", sa.JSON()), - sa.column("error", sa.JSON()), - sa.column("usage", sa.JSON()), - sa.column("created_at", sa.BigInteger()), - sa.column("updated_at", sa.BigInteger()), + 'chat_message', + sa.column('id', sa.Text()), + sa.column('chat_id', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('role', sa.Text()), + sa.column('parent_id', sa.Text()), + sa.column('content', sa.JSON()), + sa.column('output', sa.JSON()), + sa.column('model_id', sa.Text()), + sa.column('files', sa.JSON()), + sa.column('sources', sa.JSON()), + sa.column('embeds', sa.JSON()), + sa.column('done', sa.Boolean()), + sa.column('status_history', sa.JSON()), + sa.column('error', sa.JSON()), + sa.column('usage', sa.JSON()), + sa.column('created_at', sa.BigInteger()), + sa.column('updated_at', sa.BigInteger()), ) # Stream rows instead of loading all into memory: @@ -126,7 +120,7 @@ def upgrade() -> None: # - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite) result = conn.execute( sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat) - .where(~chat_table.c.user_id.like("shared-%")) + .where(~chat_table.c.user_id.like('shared-%')) .execution_options(yield_per=1000, stream_results=True) ) @@ -150,11 +144,11 @@ def upgrade() -> None: except Exception: continue - history = chat_data.get("history", {}) + history = chat_data.get('history', {}) if not isinstance(history, dict): continue - messages = history.get("messages", {}) + messages = history.get('messages', {}) if not isinstance(messages, dict): continue @@ -162,11 +156,11 @@ def upgrade() -> None: if not isinstance(message, dict): continue - role = message.get("role") + role = message.get('role') if not role: continue - timestamp = message.get("timestamp", now) + timestamp = message.get('timestamp', now) try: timestamp = int(float(timestamp)) @@ -182,37 +176,33 @@ def upgrade() -> None: messages_batch.append( { - "id": f"{chat_id}-{message_id}", - "chat_id": chat_id, - "user_id": user_id, - "role": role, - "parent_id": message.get("parentId"), - "content": message.get("content"), - "output": message.get("output"), - "model_id": message.get("model"), - "files": message.get("files"), - "sources": message.get("sources"), - "embeds": message.get("embeds"), - "done": message.get("done", True), - "status_history": message.get("statusHistory"), - "error": message.get("error"), - "usage": message.get("usage"), - "created_at": timestamp, - "updated_at": timestamp, + 'id': f'{chat_id}-{message_id}', + 'chat_id': chat_id, + 'user_id': user_id, + 'role': role, + 'parent_id': message.get('parentId'), + 'content': message.get('content'), + 'output': message.get('output'), + 'model_id': message.get('model'), + 'files': message.get('files'), + 'sources': message.get('sources'), + 'embeds': message.get('embeds'), + 'done': message.get('done', True), + 'status_history': message.get('statusHistory'), + 'error': message.get('error'), + 'usage': message.get('usage'), + 'created_at': timestamp, + 'updated_at': timestamp, } ) # Flush batch when full if len(messages_batch) >= BATCH_SIZE: - inserted, failed = _flush_batch( - conn, chat_message_table, messages_batch - ) + inserted, failed = _flush_batch(conn, chat_message_table, messages_batch) total_inserted += inserted total_failed += failed if total_inserted % 50000 < BATCH_SIZE: - log.info( - f"Migration progress: {total_inserted} messages inserted..." - ) + log.info(f'Migration progress: {total_inserted} messages inserted...') messages_batch.clear() # Flush remaining messages @@ -221,13 +211,11 @@ def upgrade() -> None: total_inserted += inserted total_failed += failed - log.info( - f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)" - ) + log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)') def downgrade() -> None: - op.drop_index("chat_message_user_created_idx", table_name="chat_message") - op.drop_index("chat_message_model_created_idx", table_name="chat_message") - op.drop_index("chat_message_chat_parent_idx", table_name="chat_message") - op.drop_table("chat_message") + op.drop_index('chat_message_user_created_idx', table_name='chat_message') + op.drop_index('chat_message_model_created_idx', table_name='chat_message') + op.drop_index('chat_message_chat_parent_idx', table_name='chat_message') + op.drop_table('chat_message') diff --git a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py index 8b9e338309..9d115b1e5c 100644 --- a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py +++ b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py @@ -13,48 +13,46 @@ import sqlalchemy as sa import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "90ef40d4714e" -down_revision: Union[str, None] = "b10670c03dd5" +revision: str = '90ef40d4714e' +down_revision: Union[str, None] = 'b10670c03dd5' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Update 'channel' table - op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True)) + op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True)) - op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True)) - op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True)) + op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True)) - op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True)) - op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True)) + op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True)) - op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True)) # Update 'channel_member' table - op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True)) - op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True)) - op.add_column( - "channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True)) # Create 'channel_webhook' table op.create_table( - "channel_webhook", - sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), + 'channel_webhook', + sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "channel_id", + 'channel_id', sa.Text(), - sa.ForeignKey("channel.id", ondelete="CASCADE"), + sa.ForeignKey('channel.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("profile_image_url", sa.Text(), nullable=True), - sa.Column("token", sa.Text(), nullable=False), - sa.Column("last_used_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('token', sa.Text(), nullable=False), + sa.Column('last_used_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) pass @@ -62,19 +60,19 @@ def upgrade() -> None: def downgrade() -> None: # Downgrade 'channel' table - op.drop_column("channel", "is_private") - op.drop_column("channel", "archived_at") - op.drop_column("channel", "archived_by") - op.drop_column("channel", "deleted_at") - op.drop_column("channel", "deleted_by") - op.drop_column("channel", "updated_by") + op.drop_column('channel', 'is_private') + op.drop_column('channel', 'archived_at') + op.drop_column('channel', 'archived_by') + op.drop_column('channel', 'deleted_at') + op.drop_column('channel', 'deleted_by') + op.drop_column('channel', 'updated_by') # Downgrade 'channel_member' table - op.drop_column("channel_member", "role") - op.drop_column("channel_member", "invited_by") - op.drop_column("channel_member", "invited_at") + op.drop_column('channel_member', 'role') + op.drop_column('channel_member', 'invited_by') + op.drop_column('channel_member', 'invited_at') # Drop 'channel_webhook' table - op.drop_table("channel_webhook") + op.drop_table('channel_webhook') pass diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py index a752115844..5e617be1e6 100644 --- a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -9,38 +9,38 @@ Create Date: 2024-11-14 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "922e7a387820" -down_revision = "4ace53fd72c8" +revision = '922e7a387820' +down_revision = '4ace53fd72c8' branch_labels = None depends_on = None def upgrade(): op.create_table( - "group", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("permissions", sa.JSON(), nullable=True), - sa.Column("user_ids", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'group', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('permissions', sa.JSON(), nullable=True), + sa.Column('user_ids', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) # Add 'access_control' column to 'model' table op.add_column( - "model", - sa.Column("access_control", sa.JSON(), nullable=True), + 'model', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'is_active' column to 'model' table op.add_column( - "model", + 'model', sa.Column( - "is_active", + 'is_active', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true(), @@ -49,37 +49,37 @@ def upgrade(): # Add 'access_control' column to 'knowledge' table op.add_column( - "knowledge", - sa.Column("access_control", sa.JSON(), nullable=True), + 'knowledge', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'access_control' column to 'prompt' table op.add_column( - "prompt", - sa.Column("access_control", sa.JSON(), nullable=True), + 'prompt', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'access_control' column to 'tools' table op.add_column( - "tool", - sa.Column("access_control", sa.JSON(), nullable=True), + 'tool', + sa.Column('access_control', sa.JSON(), nullable=True), ) def downgrade(): - op.drop_table("group") + op.drop_table('group') # Drop 'access_control' column from 'model' table - op.drop_column("model", "access_control") + op.drop_column('model', 'access_control') # Drop 'is_active' column from 'model' table - op.drop_column("model", "is_active") + op.drop_column('model', 'is_active') # Drop 'access_control' column from 'knowledge' table - op.drop_column("knowledge", "access_control") + op.drop_column('knowledge', 'access_control') # Drop 'access_control' column from 'prompt' table - op.drop_column("prompt", "access_control") + op.drop_column('prompt', 'access_control') # Drop 'access_control' column from 'tools' table - op.drop_column("tool", "access_control") + op.drop_column('tool', 'access_control') diff --git a/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py b/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py index 8e983a2cff..c75db04ca5 100644 --- a/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py +++ b/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py @@ -9,25 +9,25 @@ Create Date: 2025-05-03 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "9f0c9cd09105" -down_revision = "3781e22d8b01" +revision = '9f0c9cd09105' +down_revision = '3781e22d8b01' branch_labels = None depends_on = None def upgrade(): op.create_table( - "note", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'note', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) def downgrade(): - op.drop_table("note") + op.drop_table('note') diff --git a/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py index 26e9e66240..f11f7d8d1b 100644 --- a/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py +++ b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py @@ -13,8 +13,8 @@ import sqlalchemy as sa from open_webui.migrations.util import get_existing_tables -revision: str = "a1b2c3d4e5f6" -down_revision: Union[str, None] = "f1e2d3c4b5a6" +revision: str = 'a1b2c3d4e5f6' +down_revision: Union[str, None] = 'f1e2d3c4b5a6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,24 +22,24 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: existing_tables = set(get_existing_tables()) - if "skill" not in existing_tables: + if 'skill' not in existing_tables: op.create_table( - "skill", - sa.Column("id", sa.String(), nullable=False, primary_key=True), - sa.Column("user_id", sa.String(), nullable=False), - sa.Column("name", sa.Text(), nullable=False, unique=True), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=False), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'skill', + sa.Column('id', sa.String(), nullable=False, primary_key=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('name', sa.Text(), nullable=False, unique=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), ) - op.create_index("idx_skill_user_id", "skill", ["user_id"]) - op.create_index("idx_skill_updated_at", "skill", ["updated_at"]) + op.create_index('idx_skill_user_id', 'skill', ['user_id']) + op.create_index('idx_skill_updated_at', 'skill', ['updated_at']) def downgrade() -> None: - op.drop_index("idx_skill_updated_at", table_name="skill") - op.drop_index("idx_skill_user_id", table_name="skill") - op.drop_table("skill") + op.drop_index('idx_skill_updated_at', table_name='skill') + op.drop_index('idx_skill_user_id', table_name='skill') + op.drop_table('skill') diff --git a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py index dd2b7d1a68..29157baa07 100644 --- a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py +++ b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py @@ -12,8 +12,8 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "a5c220713937" -down_revision: Union[str, None] = "38d63c18f30f" +revision: str = 'a5c220713937' +down_revision: Union[str, None] = '38d63c18f30f' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,14 +21,14 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add 'reply_to_id' column to the 'message' table for replying to messages op.add_column( - "message", - sa.Column("reply_to_id", sa.Text(), nullable=True), + 'message', + sa.Column('reply_to_id', sa.Text(), nullable=True), ) pass def downgrade() -> None: # Remove 'reply_to_id' column from the 'message' table - op.drop_column("message", "reply_to_id") + op.drop_column('message', 'reply_to_id') pass diff --git a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py index 9116aa3884..4d8fd63e80 100644 --- a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py +++ b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py @@ -10,8 +10,8 @@ from alembic import op import sqlalchemy as sa # Revision identifiers, used by Alembic. -revision = "af906e964978" -down_revision = "c29facfe716b" +revision = 'af906e964978' +down_revision = 'c29facfe716b' branch_labels = None depends_on = None @@ -19,33 +19,23 @@ depends_on = None def upgrade(): # ### Create feedback table ### op.create_table( - "feedback", + 'feedback', + sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type) + sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type) + sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type) + sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type) + sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type) + sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type) + sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type) sa.Column( - "id", sa.Text(), primary_key=True - ), # Unique identifier for each feedback (TEXT type) - sa.Column( - "user_id", sa.Text(), nullable=True - ), # ID of the user providing the feedback (TEXT type) - sa.Column( - "version", sa.BigInteger(), default=0 - ), # Version of feedback (BIGINT type) - sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type) - sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type) - sa.Column( - "meta", sa.JSON(), nullable=True - ), # Metadata for feedback (JSON type) - sa.Column( - "snapshot", sa.JSON(), nullable=True - ), # snapshot data for feedback (JSON type) - sa.Column( - "created_at", sa.BigInteger(), nullable=False + 'created_at', sa.BigInteger(), nullable=False ), # Feedback creation timestamp (BIGINT representing epoch) sa.Column( - "updated_at", sa.BigInteger(), nullable=False + 'updated_at', sa.BigInteger(), nullable=False ), # Feedback update timestamp (BIGINT representing epoch) ) def downgrade(): # ### Drop feedback table ### - op.drop_table("feedback") + op.drop_table('feedback') diff --git a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py index 0472c08616..623289d885 100644 --- a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py +++ b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py @@ -17,8 +17,8 @@ import json import time # revision identifiers, used by Alembic. -revision: str = "b10670c03dd5" -down_revision: Union[str, None] = "2f1211949ecc" +revision: str = 'b10670c03dd5' +down_revision: Union[str, None] = '2f1211949ecc' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -33,13 +33,11 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn): for idx in indexes: index_name = idx[1] # index name # Get indexed columns - idx_info = conn.execute( - sa.text(f"PRAGMA index_info('{index_name}')") - ).fetchall() + idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall() indexed_cols = [row[2] for row in idx_info] # col names if column_name in indexed_cols: - conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}")) + conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}')) def _convert_column_to_json(table: str, column: str): @@ -47,9 +45,9 @@ def _convert_column_to_json(table: str, column: str): dialect = conn.dialect.name # SQLite cannot ALTER COLUMN → must recreate column - if dialect == "sqlite": + if dialect == 'sqlite': # 1. Add temporary column - op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True)) + op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True)) # 2. Load old data rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() @@ -66,14 +64,14 @@ def _convert_column_to_json(table: str, column: str): conn.execute( sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'), - {"val": json.dumps(parsed) if parsed else None, "id": uid}, + {'val': json.dumps(parsed) if parsed else None, 'id': uid}, ) # 3. Drop old TEXT column op.drop_column(table, column) # 4. Rename new JSON column → original name - op.alter_column(table, f"{column}_json", new_column_name=column) + op.alter_column(table, f'{column}_json', new_column_name=column) else: # PostgreSQL supports direct CAST @@ -81,7 +79,7 @@ def _convert_column_to_json(table: str, column: str): table, column, type_=sa.JSON(), - postgresql_using=f"{column}::json", + postgresql_using=f'{column}::json', ) @@ -89,85 +87,77 @@ def _convert_column_to_text(table: str, column: str): conn = op.get_bind() dialect = conn.dialect.name - if dialect == "sqlite": - op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True)) + if dialect == 'sqlite': + op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True)) rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() for uid, raw in rows: conn.execute( sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'), - {"val": json.dumps(raw) if raw else None, "id": uid}, + {'val': json.dumps(raw) if raw else None, 'id': uid}, ) op.drop_column(table, column) - op.alter_column(table, f"{column}_text", new_column_name=column) + op.alter_column(table, f'{column}_text', new_column_name=column) else: op.alter_column( table, column, type_=sa.Text(), - postgresql_using=f"to_json({column})::text", + postgresql_using=f'to_json({column})::text', ) def upgrade() -> None: - op.add_column( - "user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True) - ) - op.add_column("user", sa.Column("timezone", sa.String(), nullable=True)) + op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('timezone', sa.String(), nullable=True)) - op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True)) - op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True)) - op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True)) - op.add_column( - "user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True) - ) + op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True)) + op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True)) + op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True)) - op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True)) + op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True)) # Convert info (TEXT/JSONField) → JSON - _convert_column_to_json("user", "info") + _convert_column_to_json('user', 'info') # Convert settings (TEXT/JSONField) → JSON - _convert_column_to_json("user", "settings") + _convert_column_to_json('user', 'settings') op.create_table( - "api_key", - sa.Column("id", sa.Text(), primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")), - sa.Column("key", sa.Text(), unique=True, nullable=False), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("expires_at", sa.BigInteger(), nullable=True), - sa.Column("last_used_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + 'api_key', + sa.Column('id', sa.Text(), primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')), + sa.Column('key', sa.Text(), unique=True, nullable=False), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('expires_at', sa.BigInteger(), nullable=True), + sa.Column('last_used_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) conn = op.get_bind() - users = conn.execute( - sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL') - ).fetchall() + users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall() for uid, oauth_sub in users: if oauth_sub: # Example formats supported: # provider@sub # plain sub (stored as {"oidc": {"sub": sub}}) - if "@" in oauth_sub: - provider, sub = oauth_sub.split("@", 1) + if '@' in oauth_sub: + provider, sub = oauth_sub.split('@', 1) else: - provider, sub = "oidc", oauth_sub + provider, sub = 'oidc', oauth_sub - oauth_json = json.dumps({provider: {"sub": sub}}) + oauth_json = json.dumps({provider: {'sub': sub}}) conn.execute( sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'), - {"oauth": oauth_json, "id": uid}, + {'oauth': oauth_json, 'id': uid}, ) - users_with_keys = conn.execute( - sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL') - ).fetchall() + users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall() now = int(time.time()) for uid, api_key in users_with_keys: @@ -178,72 +168,70 @@ def upgrade() -> None: VALUES (:id, :user_id, :key, :created_at, :updated_at) """), { - "id": f"key_{uid}", - "user_id": uid, - "key": api_key, - "created_at": now, - "updated_at": now, + 'id': f'key_{uid}', + 'user_id': uid, + 'key': api_key, + 'created_at': now, + 'updated_at': now, }, ) - if conn.dialect.name == "sqlite": - _drop_sqlite_indexes_for_column("user", "api_key", conn) - _drop_sqlite_indexes_for_column("user", "oauth_sub", conn) + if conn.dialect.name == 'sqlite': + _drop_sqlite_indexes_for_column('user', 'api_key', conn) + _drop_sqlite_indexes_for_column('user', 'oauth_sub', conn) - with op.batch_alter_table("user") as batch_op: - batch_op.drop_column("api_key") - batch_op.drop_column("oauth_sub") + with op.batch_alter_table('user') as batch_op: + batch_op.drop_column('api_key') + batch_op.drop_column('oauth_sub') def downgrade() -> None: # --- 1. Restore old oauth_sub column --- - op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True)) + op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True)) conn = op.get_bind() - users = conn.execute( - sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL') - ).fetchall() + users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall() for uid, oauth in users: try: data = json.loads(oauth) provider = list(data.keys())[0] - sub = data[provider].get("sub") - oauth_sub = f"{provider}@{sub}" + sub = data[provider].get('sub') + oauth_sub = f'{provider}@{sub}' except Exception: oauth_sub = None conn.execute( sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'), - {"oauth_sub": oauth_sub, "id": uid}, + {'oauth_sub': oauth_sub, 'id': uid}, ) - op.drop_column("user", "oauth") + op.drop_column('user', 'oauth') # --- 2. Restore api_key field --- - op.add_column("user", sa.Column("api_key", sa.String(), nullable=True)) + op.add_column('user', sa.Column('api_key', sa.String(), nullable=True)) # Restore values from api_key - keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall() + keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall() for uid, key in keys: conn.execute( sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'), - {"key": key, "id": uid}, + {'key': key, 'id': uid}, ) # Drop new table - op.drop_table("api_key") + op.drop_table('api_key') - with op.batch_alter_table("user") as batch_op: - batch_op.drop_column("profile_banner_image_url") - batch_op.drop_column("timezone") + with op.batch_alter_table('user') as batch_op: + batch_op.drop_column('profile_banner_image_url') + batch_op.drop_column('timezone') - batch_op.drop_column("presence_state") - batch_op.drop_column("status_emoji") - batch_op.drop_column("status_message") - batch_op.drop_column("status_expires_at") + batch_op.drop_column('presence_state') + batch_op.drop_column('status_emoji') + batch_op.drop_column('status_message') + batch_op.drop_column('status_expires_at') # Convert info (JSON) → TEXT - _convert_column_to_text("user", "info") + _convert_column_to_text('user', 'info') # Convert settings (JSON) → TEXT - _convert_column_to_text("user", "settings") + _convert_column_to_text('user', 'settings') diff --git a/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py b/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py index e8bf9a850f..e3668d3b6e 100644 --- a/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py +++ b/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py @@ -12,15 +12,15 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "b2c3d4e5f6a7" -down_revision: Union[str, None] = "a1b2c3d4e5f6" +revision: str = 'b2c3d4e5f6a7' +down_revision: Union[str, None] = 'a1b2c3d4e5f6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True)) + op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True)) def downgrade() -> None: - op.drop_column("user", "scim") + op.drop_column('user', 'scim') diff --git a/backend/open_webui/migrations/versions/c0fbf31ca0db_update_file_table.py b/backend/open_webui/migrations/versions/c0fbf31ca0db_update_file_table.py index 5f7f2abf70..709b644150 100644 --- a/backend/open_webui/migrations/versions/c0fbf31ca0db_update_file_table.py +++ b/backend/open_webui/migrations/versions/c0fbf31ca0db_update_file_table.py @@ -12,21 +12,21 @@ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision: str = "c0fbf31ca0db" -down_revision: Union[str, None] = "ca81bd47c050" +revision: str = 'c0fbf31ca0db' +down_revision: Union[str, None] = 'ca81bd47c050' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column("file", sa.Column("hash", sa.Text(), nullable=True)) - op.add_column("file", sa.Column("data", sa.JSON(), nullable=True)) - op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True)) + op.add_column('file', sa.Column('hash', sa.Text(), nullable=True)) + op.add_column('file', sa.Column('data', sa.JSON(), nullable=True)) + op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True)) def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("file", "updated_at") - op.drop_column("file", "data") - op.drop_column("file", "hash") + op.drop_column('file', 'updated_at') + op.drop_column('file', 'data') + op.drop_column('file', 'hash') diff --git a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py index 7786de425f..37fe63ef15 100644 --- a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py +++ b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py @@ -12,35 +12,33 @@ import json from sqlalchemy.sql import table, column from sqlalchemy import String, Text, JSON, and_ -revision = "c29facfe716b" -down_revision = "c69f45358db4" +revision = 'c29facfe716b' +down_revision = 'c69f45358db4' branch_labels = None depends_on = None def upgrade(): # 1. Add the `path` column to the "file" table. - op.add_column("file", sa.Column("path", sa.Text(), nullable=True)) + op.add_column('file', sa.Column('path', sa.Text(), nullable=True)) # 2. Convert the `meta` column from Text/JSONField to `JSON()` # Use Alembic's default batch_op for dialect compatibility. - with op.batch_alter_table("file", schema=None) as batch_op: + with op.batch_alter_table('file', schema=None) as batch_op: batch_op.alter_column( - "meta", + 'meta', type_=sa.JSON(), existing_type=sa.Text(), existing_nullable=True, nullable=True, - postgresql_using="meta::json", + postgresql_using='meta::json', ) # 3. Migrate legacy data from `meta` JSONField # Fetch and process `meta` data from the table, add values to the new `path` column as necessary. # We will use SQLAlchemy core bindings to ensure safety across different databases. - file_table = table( - "file", column("id", String), column("meta", JSON), column("path", Text) - ) + file_table = table('file', column('id', String), column('meta', JSON), column('path', Text)) # Create connection to the database connection = op.get_bind() @@ -55,24 +53,18 @@ def upgrade(): # Iterate over each row to extract and update the `path` from `meta` column for row in results: - if "path" in row.meta: + if 'path' in row.meta: # Extract the `path` field from the `meta` JSON - path = row.meta.get("path") + path = row.meta.get('path') # Update the `file` table with the new `path` value - connection.execute( - file_table.update() - .where(file_table.c.id == row.id) - .values({"path": path}) - ) + connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path})) def downgrade(): # 1. Remove the `path` column - op.drop_column("file", "path") + op.drop_column('file', 'path') # 2. Revert the `meta` column back to Text/JSONField - with op.batch_alter_table("file", schema=None) as batch_op: - batch_op.alter_column( - "meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True - ) + with op.batch_alter_table('file', schema=None) as batch_op: + batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True) diff --git a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py index fa818e1f08..0eae928b91 100644 --- a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py +++ b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py @@ -12,45 +12,43 @@ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "c440947495f3" -down_revision: Union[str, None] = "81cc2ce44d79" +revision: str = 'c440947495f3' +down_revision: Union[str, None] = '81cc2ce44d79' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "chat_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'chat_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "chat_id", + 'chat_id', sa.Text(), - sa.ForeignKey("chat.id", ondelete="CASCADE"), + sa.ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("message_id", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('message_id', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_chat_file_chat_id", "chat_id"), - sa.Index("ix_chat_file_file_id", "file_id"), - sa.Index("ix_chat_file_message_id", "message_id"), - sa.Index("ix_chat_file_user_id", "user_id"), + sa.Index('ix_chat_file_chat_id', 'chat_id'), + sa.Index('ix_chat_file_file_id', 'file_id'), + sa.Index('ix_chat_file_message_id', 'message_id'), + sa.Index('ix_chat_file_user_id', 'user_id'), # unique constraints - sa.UniqueConstraint( - "chat_id", "file_id", name="uq_chat_file_chat_file" - ), # prevent duplicate entries + sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries ) pass def downgrade() -> None: - op.drop_table("chat_file") + op.drop_table('chat_file') pass diff --git a/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py index 83e0dc28ed..c9572fe7a3 100644 --- a/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py +++ b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py @@ -9,42 +9,40 @@ Create Date: 2024-10-16 02:02:35.241684 from alembic import op import sqlalchemy as sa -revision = "c69f45358db4" -down_revision = "3ab32c4b8f59" +revision = 'c69f45358db4' +down_revision = '3ab32c4b8f59' branch_labels = None depends_on = None def upgrade(): op.create_table( - "folder", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("items", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False), + 'folder', + sa.Column('id', sa.Text(), nullable=False), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('items', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column( - "created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False - ), - sa.Column( - "updated_at", + 'updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now(), onupdate=sa.func.now(), ), - sa.PrimaryKeyConstraint("id", "user_id"), + sa.PrimaryKeyConstraint('id', 'user_id'), ) op.add_column( - "chat", - sa.Column("folder_id", sa.Text(), nullable=True), + 'chat', + sa.Column('folder_id', sa.Text(), nullable=True), ) def downgrade(): - op.drop_column("chat", "folder_id") + op.drop_column('chat', 'folder_id') - op.drop_table("folder") + op.drop_table('folder') diff --git a/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py b/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py index 1540aa6a7f..5fdf933dd6 100644 --- a/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py +++ b/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py @@ -12,23 +12,21 @@ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision: str = "ca81bd47c050" -down_revision: Union[str, None] = "7e5b5dc7342b" +revision: str = 'ca81bd47c050' +down_revision: Union[str, None] = '7e5b5dc7342b' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade(): op.create_table( - "config", - sa.Column("id", sa.Integer, primary_key=True), - sa.Column("data", sa.JSON(), nullable=False), - sa.Column("version", sa.Integer, nullable=False), + 'config', + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('data', sa.JSON(), nullable=False), + sa.Column('version', sa.Integer, nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()), sa.Column( - "created_at", sa.DateTime(), nullable=False, server_default=sa.func.now() - ), - sa.Column( - "updated_at", + 'updated_at', sa.DateTime(), nullable=True, server_default=sa.func.now(), @@ -38,4 +36,4 @@ def upgrade(): def downgrade(): - op.drop_table("config") + op.drop_table('config') diff --git a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py index 3c916964e9..444e131db7 100644 --- a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py +++ b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py @@ -9,15 +9,15 @@ Create Date: 2025-07-13 03:00:00.000000 from alembic import op import sqlalchemy as sa -revision = "d31026856c01" -down_revision = "9f0c9cd09105" +revision = 'd31026856c01' +down_revision = '9f0c9cd09105' branch_labels = None depends_on = None def upgrade(): - op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True)) + op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True)) def downgrade(): - op.drop_column("folder", "data") + op.drop_column('folder', 'data') diff --git a/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py index 5569718dd8..5ed572cf7a 100644 --- a/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py +++ b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py @@ -20,8 +20,8 @@ import sqlalchemy as sa from open_webui.migrations.util import get_existing_tables -revision: str = "f1e2d3c4b5a6" -down_revision: Union[str, None] = "8452d01d26d7" +revision: str = 'f1e2d3c4b5a6' +down_revision: Union[str, None] = '8452d01d26d7' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -30,34 +30,34 @@ def upgrade() -> None: existing_tables = set(get_existing_tables()) # Create access_grant table - if "access_grant" not in existing_tables: + if 'access_grant' not in existing_tables: op.create_table( - "access_grant", - sa.Column("id", sa.Text(), nullable=False, primary_key=True), - sa.Column("resource_type", sa.Text(), nullable=False), - sa.Column("resource_id", sa.Text(), nullable=False), - sa.Column("principal_type", sa.Text(), nullable=False), - sa.Column("principal_id", sa.Text(), nullable=False), - sa.Column("permission", sa.Text(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'access_grant', + sa.Column('id', sa.Text(), nullable=False, primary_key=True), + sa.Column('resource_type', sa.Text(), nullable=False), + sa.Column('resource_id', sa.Text(), nullable=False), + sa.Column('principal_type', sa.Text(), nullable=False), + sa.Column('principal_id', sa.Text(), nullable=False), + sa.Column('permission', sa.Text(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), sa.UniqueConstraint( - "resource_type", - "resource_id", - "principal_type", - "principal_id", - "permission", - name="uq_access_grant_grant", + 'resource_type', + 'resource_id', + 'principal_type', + 'principal_id', + 'permission', + name='uq_access_grant_grant', ), ) op.create_index( - "idx_access_grant_resource", - "access_grant", - ["resource_type", "resource_id"], + 'idx_access_grant_resource', + 'access_grant', + ['resource_type', 'resource_id'], ) op.create_index( - "idx_access_grant_principal", - "access_grant", - ["principal_type", "principal_id"], + 'idx_access_grant_principal', + 'access_grant', + ['principal_type', 'principal_id'], ) # Backfill existing access_control JSON data @@ -65,13 +65,13 @@ def upgrade() -> None: # Tables with access_control JSON columns: (table_name, resource_type) resource_tables = [ - ("knowledge", "knowledge"), - ("prompt", "prompt"), - ("tool", "tool"), - ("model", "model"), - ("note", "note"), - ("channel", "channel"), - ("file", "file"), + ('knowledge', 'knowledge'), + ('prompt', 'prompt'), + ('tool', 'tool'), + ('model', 'model'), + ('note', 'note'), + ('channel', 'channel'), + ('file', 'file'), ] now = int(time.time()) @@ -83,9 +83,7 @@ def upgrade() -> None: # Query all rows try: - result = conn.execute( - sa.text(f'SELECT id, access_control FROM "{table_name}"') - ) + result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"')) rows = result.fetchall() except Exception: continue @@ -99,19 +97,16 @@ def upgrade() -> None: # EXCEPTION: files with NULL are PRIVATE (owner-only), not public is_null = ( access_control_json is None - or access_control_json == "null" - or ( - isinstance(access_control_json, str) - and access_control_json.strip().lower() == "null" - ) + or access_control_json == 'null' + or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null') ) if is_null: # Files: NULL = private (no entry needed, owner has implicit access) # Other resources: NULL = public (insert user:* for read) - if resource_type == "file": + if resource_type == 'file': continue # Private - no entry needed - key = (resource_type, resource_id, "user", "*", "read") + key = (resource_type, resource_id, 'user', '*', 'read') if key not in inserted: try: conn.execute( @@ -120,13 +115,13 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": "*", - "permission": "read", - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': '*', + 'permission': 'read', + 'created_at': now, }, ) inserted.add(key) @@ -149,28 +144,24 @@ def upgrade() -> None: continue # Check if it's effectively empty (no read/write keys with content) - read_data = access_control_json.get("read", {}) - write_data = access_control_json.get("write", {}) + read_data = access_control_json.get('read', {}) + write_data = access_control_json.get('write', {}) - has_read_grants = read_data.get("group_ids", []) or read_data.get( - "user_ids", [] - ) - has_write_grants = write_data.get("group_ids", []) or write_data.get( - "user_ids", [] - ) + has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', []) + has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', []) if not has_read_grants and not has_write_grants: # Empty permissions = private, no grants needed continue # Extract permissions and insert into access_grant table - for permission in ["read", "write"]: + for permission in ['read', 'write']: perm_data = access_control_json.get(permission, {}) if not perm_data: continue - for group_id in perm_data.get("group_ids", []): - key = (resource_type, resource_id, "group", group_id, permission) + for group_id in perm_data.get('group_ids', []): + key = (resource_type, resource_id, 'group', group_id, permission) if key in inserted: continue try: @@ -180,21 +171,21 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "group", - "principal_id": group_id, - "permission": permission, - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'group', + 'principal_id': group_id, + 'permission': permission, + 'created_at': now, }, ) inserted.add(key) except Exception: pass - for user_id in perm_data.get("user_ids", []): - key = (resource_type, resource_id, "user", user_id, permission) + for user_id in perm_data.get('user_ids', []): + key = (resource_type, resource_id, 'user', user_id, permission) if key in inserted: continue try: @@ -204,13 +195,13 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": user_id, - "permission": permission, - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': user_id, + 'permission': permission, + 'created_at': now, }, ) inserted.add(key) @@ -223,7 +214,7 @@ def upgrade() -> None: continue try: with op.batch_alter_table(table_name) as batch: - batch.drop_column("access_control") + batch.drop_column('access_control') except Exception: pass @@ -235,20 +226,20 @@ def downgrade() -> None: # Resource tables mapping: (table_name, resource_type) resource_tables = [ - ("knowledge", "knowledge"), - ("prompt", "prompt"), - ("tool", "tool"), - ("model", "model"), - ("note", "note"), - ("channel", "channel"), - ("file", "file"), + ('knowledge', 'knowledge'), + ('prompt', 'prompt'), + ('tool', 'tool'), + ('model', 'model'), + ('note', 'note'), + ('channel', 'channel'), + ('file', 'file'), ] # Step 1: Re-add access_control columns to resource tables for table_name, _ in resource_tables: try: with op.batch_alter_table(table_name) as batch: - batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True)) + batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True)) except Exception: pass @@ -262,7 +253,7 @@ def downgrade() -> None: FROM access_grant WHERE resource_type = :resource_type """), - {"resource_type": resource_type}, + {'resource_type': resource_type}, ) rows = result.fetchall() except Exception: @@ -278,49 +269,35 @@ def downgrade() -> None: if resource_id not in resource_grants: resource_grants[resource_id] = { - "is_public": False, - "read": {"group_ids": [], "user_ids": []}, - "write": {"group_ids": [], "user_ids": []}, + 'is_public': False, + 'read': {'group_ids': [], 'user_ids': []}, + 'write': {'group_ids': [], 'user_ids': []}, } # Handle public access (user:* for read) - if ( - principal_type == "user" - and principal_id == "*" - and permission == "read" - ): - resource_grants[resource_id]["is_public"] = True + if principal_type == 'user' and principal_id == '*' and permission == 'read': + resource_grants[resource_id]['is_public'] = True continue # Add to appropriate list - if permission in ["read", "write"]: - if principal_type == "group": - if ( - principal_id - not in resource_grants[resource_id][permission]["group_ids"] - ): - resource_grants[resource_id][permission]["group_ids"].append( - principal_id - ) - elif principal_type == "user": - if ( - principal_id - not in resource_grants[resource_id][permission]["user_ids"] - ): - resource_grants[resource_id][permission]["user_ids"].append( - principal_id - ) + if permission in ['read', 'write']: + if principal_type == 'group': + if principal_id not in resource_grants[resource_id][permission]['group_ids']: + resource_grants[resource_id][permission]['group_ids'].append(principal_id) + elif principal_type == 'user': + if principal_id not in resource_grants[resource_id][permission]['user_ids']: + resource_grants[resource_id][permission]['user_ids'].append(principal_id) # Step 3: Update each resource with reconstructed JSON for resource_id, grants in resource_grants.items(): - if grants["is_public"]: + if grants['is_public']: # Public = NULL access_control_value = None elif ( - not grants["read"]["group_ids"] - and not grants["read"]["user_ids"] - and not grants["write"]["group_ids"] - and not grants["write"]["user_ids"] + not grants['read']['group_ids'] + and not grants['read']['user_ids'] + and not grants['write']['group_ids'] + and not grants['write']['user_ids'] ): # No grants = should not happen (would mean no entries), default to {} access_control_value = json.dumps({}) @@ -328,17 +305,15 @@ def downgrade() -> None: # Custom permissions access_control_value = json.dumps( { - "read": grants["read"], - "write": grants["write"], + 'read': grants['read'], + 'write': grants['write'], } ) try: conn.execute( - sa.text( - f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id' - ), - {"access_control": access_control_value, "id": resource_id}, + sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'), + {'access_control': access_control_value, 'id': resource_id}, ) except Exception: pass @@ -346,7 +321,7 @@ def downgrade() -> None: # Step 4: Set all resources WITHOUT entries to private # For files: NULL means private (owner-only), so leave as NULL # For other resources: {} means private, so update to {} - if resource_type != "file": + if resource_type != 'file': try: conn.execute( sa.text(f""" @@ -357,13 +332,13 @@ def downgrade() -> None: ) AND access_control IS NULL """), - {"private_value": json.dumps({}), "resource_type": resource_type}, + {'private_value': json.dumps({}), 'resource_type': resource_type}, ) except Exception: pass # For files, NULL stays NULL - no action needed # Step 5: Drop the access_grant table - op.drop_index("idx_access_grant_principal", table_name="access_grant") - op.drop_index("idx_access_grant_resource", table_name="access_grant") - op.drop_table("access_grant") + op.drop_index('idx_access_grant_principal', table_name='access_grant') + op.drop_index('idx_access_grant_resource', table_name='access_grant') + op.drop_table('access_grant') diff --git a/backend/open_webui/models/access_grants.py b/backend/open_webui/models/access_grants.py index 4519abc964..ee7f950ff5 100644 --- a/backend/open_webui/models/access_grants.py +++ b/backend/open_webui/models/access_grants.py @@ -19,28 +19,24 @@ log = logging.getLogger(__name__) class AccessGrant(Base): - __tablename__ = "access_grant" + __tablename__ = 'access_grant' id = Column(Text, primary_key=True) - resource_type = Column( - Text, nullable=False - ) # "knowledge", "model", "prompt", "tool", "note", "channel", "file" + resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file" resource_id = Column(Text, nullable=False) principal_type = Column(Text, nullable=False) # "user" or "group" - principal_id = Column( - Text, nullable=False - ) # user_id, group_id, or "*" (wildcard for public) + principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public) permission = Column(Text, nullable=False) # "read" or "write" created_at = Column(BigInteger, nullable=False) __table_args__ = ( UniqueConstraint( - "resource_type", - "resource_id", - "principal_type", - "principal_id", - "permission", - name="uq_access_grant_grant", + 'resource_type', + 'resource_id', + 'principal_type', + 'principal_id', + 'permission', + name='uq_access_grant_grant', ), ) @@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel): permission: str @classmethod - def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse": + def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse': return cls( id=grant.id, principal_type=grant.principal_type, @@ -100,14 +96,14 @@ def access_control_to_grants( if access_control is None: # NULL → public read (user:* for read) # Exception: files with NULL are private (owner-only), no grants needed - if resource_type != "file": + if resource_type != 'file': grants.append( { - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": "*", - "permission": "read", + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': '*', + 'permission': 'read', } ) return grants @@ -117,30 +113,30 @@ def access_control_to_grants( return grants # Parse structured permissions - for permission in ["read", "write"]: + for permission in ['read', 'write']: perm_data = access_control.get(permission, {}) if not perm_data: continue - for group_id in perm_data.get("group_ids", []): + for group_id in perm_data.get('group_ids', []): grants.append( { - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "group", - "principal_id": group_id, - "permission": permission, + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'group', + 'principal_id': group_id, + 'permission': permission, } ) - for user_id in perm_data.get("user_ids", []): + for user_id in perm_data.get('user_ids', []): grants.append( { - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": user_id, - "permission": permission, + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': user_id, + 'permission': permission, } ) @@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]: if not isinstance(grant, dict): continue - principal_type = grant.get("principal_type") - principal_id = grant.get("principal_id") - permission = grant.get("permission") + principal_type = grant.get('principal_type') + principal_id = grant.get('principal_id') + permission = grant.get('permission') - if principal_type not in ("user", "group"): + if principal_type not in ('user', 'group'): continue - if permission not in ("read", "write"): + if permission not in ('read', 'write'): continue if not isinstance(principal_id, str) or not principal_id: continue key = (principal_type, principal_id, permission) deduped[key] = { - "id": ( - grant.get("id") - if isinstance(grant.get("id"), str) and grant.get("id") - else str(uuid.uuid4()) - ), - "principal_type": principal_type, - "principal_id": principal_id, - "permission": permission, + 'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())), + 'principal_type': principal_type, + 'principal_id': principal_id, + 'permission': permission, } return list(deduped.values()) @@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool: Returns True when a direct grant list includes wildcard public-read. """ for grant in normalize_access_grants(access_grants): - if ( - grant["principal_type"] == "user" - and grant["principal_id"] == "*" - and grant["permission"] == "read" - ): + if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read': return True return False @@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool: Returns True when a direct grant list includes any non-wildcard user grant. """ for grant in normalize_access_grants(access_grants): - if grant["principal_type"] == "user" and grant["principal_id"] != "*": + if grant['principal_type'] == 'user' and grant['principal_id'] != '*': return True return False @@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list: grant for grant in access_grants if not ( - ( - grant.get("principal_type") - if isinstance(grant, dict) - else getattr(grant, "principal_type", None) - ) - == "user" - and ( - grant.get("principal_id") - if isinstance(grant, dict) - else getattr(grant, "principal_id", None) - ) - != "*" + (grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None)) + == 'user' + and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*' ) ] @@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]: return {} # No grants = private/owner-only result = { - "read": {"group_ids": [], "user_ids": []}, - "write": {"group_ids": [], "user_ids": []}, + 'read': {'group_ids': [], 'user_ids': []}, + 'write': {'group_ids': [], 'user_ids': []}, } is_public = False for grant in grants: - if ( - grant.principal_type == "user" - and grant.principal_id == "*" - and grant.permission == "read" - ): + if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read': is_public = True continue # Don't add wildcard to user_ids list - if grant.permission not in ("read", "write"): + if grant.permission not in ('read', 'write'): continue - if grant.principal_type == "group": - if grant.principal_id not in result[grant.permission]["group_ids"]: - result[grant.permission]["group_ids"].append(grant.principal_id) - elif grant.principal_type == "user": - if grant.principal_id not in result[grant.permission]["user_ids"]: - result[grant.permission]["user_ids"].append(grant.principal_id) + if grant.principal_type == 'group': + if grant.principal_id not in result[grant.permission]['group_ids']: + result[grant.permission]['group_ids'].append(grant.principal_id) + elif grant.principal_type == 'user': + if grant.principal_id not in result[grant.permission]['user_ids']: + result[grant.permission]['user_ids'].append(grant.principal_id) if is_public: return None # Public read access @@ -399,9 +374,7 @@ class AccessGrantsTable: ).delete() # Convert JSON to grant dicts - grant_dicts = access_control_to_grants( - resource_type, resource_id, access_control - ) + grant_dicts = access_control_to_grants(resource_type, resource_id, access_control) # Insert new grants results = [] @@ -442,9 +415,9 @@ class AccessGrantsTable: id=str(uuid.uuid4()), resource_type=resource_type, resource_id=resource_id, - principal_type=grant_dict["principal_type"], - principal_id=grant_dict["principal_id"], - permission=grant_dict["permission"], + principal_type=grant_dict['principal_type'], + principal_id=grant_dict['principal_id'], + permission=grant_dict['permission'], created_at=int(time.time()), ) db.add(grant) @@ -511,9 +484,7 @@ class AccessGrantsTable: ) .all() ) - result: dict[str, list[AccessGrantModel]] = { - rid: [] for rid in resource_ids - } + result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids} for g in grants: result[g.resource_id].append(AccessGrantModel.model_validate(g)) return result @@ -523,7 +494,7 @@ class AccessGrantsTable: user_id: str, resource_type: str, resource_id: str, - permission: str = "read", + permission: str = 'read', user_group_ids: Optional[set[str]] = None, db: Optional[Session] = None, ) -> bool: @@ -540,12 +511,12 @@ class AccessGrantsTable: conditions = [ # Public access and_( - AccessGrant.principal_type == "user", - AccessGrant.principal_id == "*", + AccessGrant.principal_type == 'user', + AccessGrant.principal_id == '*', ), # Direct user access and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ), ] @@ -560,7 +531,7 @@ class AccessGrantsTable: if user_group_ids: conditions.append( and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(user_group_ids), ) ) @@ -582,7 +553,7 @@ class AccessGrantsTable: user_id: str, resource_type: str, resource_ids: list[str], - permission: str = "read", + permission: str = 'read', user_group_ids: Optional[set[str]] = None, db: Optional[Session] = None, ) -> set[str]: @@ -597,11 +568,11 @@ class AccessGrantsTable: with get_db_context(db) as db: conditions = [ and_( - AccessGrant.principal_type == "user", - AccessGrant.principal_id == "*", + AccessGrant.principal_type == 'user', + AccessGrant.principal_id == '*', ), and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ), ] @@ -615,7 +586,7 @@ class AccessGrantsTable: if user_group_ids: conditions.append( and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(user_group_ids), ) ) @@ -637,7 +608,7 @@ class AccessGrantsTable: self, resource_type: str, resource_id: str, - permission: str = "read", + permission: str = 'read', db: Optional[Session] = None, ) -> list: """ @@ -660,19 +631,17 @@ class AccessGrantsTable: # Check for public access for grant in grants: - if grant.principal_type == "user" and grant.principal_id == "*": - result = Users.get_users(filter={"roles": ["!pending"]}, db=db) - return result.get("users", []) + if grant.principal_type == 'user' and grant.principal_id == '*': + result = Users.get_users(filter={'roles': ['!pending']}, db=db) + return result.get('users', []) user_ids_with_access = set() for grant in grants: - if grant.principal_type == "user": + if grant.principal_type == 'user': user_ids_with_access.add(grant.principal_id) - elif grant.principal_type == "group": - group_user_ids = Groups.get_group_user_ids_by_id( - grant.principal_id, db=db - ) + elif grant.principal_type == 'group': + group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db) if group_user_ids: user_ids_with_access.update(group_user_ids) @@ -688,20 +657,18 @@ class AccessGrantsTable: DocumentModel, filter: dict, resource_type: str, - permission: str = "read", + permission: str = 'read', ): """ Apply access control filtering to a SQLAlchemy query by JOINing with access_grant. This replaces the old JSON-column-based filtering with a proper relational JOIN. """ - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") + group_ids = filter.get('group_ids', []) + user_id = filter.get('user_id') - if permission == "read_only": - return self._has_read_only_permission_filter( - db, query, DocumentModel, filter, resource_type - ) + if permission == 'read_only': + return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type) # Build principal conditions principal_conditions = [] @@ -710,8 +677,8 @@ class AccessGrantsTable: # Public access: user:* read principal_conditions.append( and_( - AccessGrant.principal_type == "user", - AccessGrant.principal_id == "*", + AccessGrant.principal_type == 'user', + AccessGrant.principal_id == '*', ) ) @@ -722,7 +689,7 @@ class AccessGrantsTable: # Direct user grant principal_conditions.append( and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ) ) @@ -731,7 +698,7 @@ class AccessGrantsTable: # Group grants principal_conditions.append( and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(group_ids), ) ) @@ -751,13 +718,13 @@ class AccessGrantsTable: AccessGrant.permission == permission, or_( and_( - AccessGrant.principal_type == "user", - AccessGrant.principal_id == "*", + AccessGrant.principal_type == 'user', + AccessGrant.principal_id == '*', ), *( [ and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ) ] @@ -767,7 +734,7 @@ class AccessGrantsTable: *( [ and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(group_ids), ) ] @@ -800,8 +767,8 @@ class AccessGrantsTable: Filter for items where user has read BUT NOT write access. Public items are NOT considered read_only. """ - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") + group_ids = filter.get('group_ids', []) + user_id = filter.get('user_id') from sqlalchemy import exists as sa_exists, select @@ -811,12 +778,12 @@ class AccessGrantsTable: .where( AccessGrant.resource_type == resource_type, AccessGrant.resource_id == DocumentModel.id, - AccessGrant.permission == "read", + AccessGrant.permission == 'read', or_( *( [ and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ) ] @@ -826,7 +793,7 @@ class AccessGrantsTable: *( [ and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(group_ids), ) ] @@ -845,12 +812,12 @@ class AccessGrantsTable: .where( AccessGrant.resource_type == resource_type, AccessGrant.resource_id == DocumentModel.id, - AccessGrant.permission == "write", + AccessGrant.permission == 'write', or_( *( [ and_( - AccessGrant.principal_type == "user", + AccessGrant.principal_type == 'user', AccessGrant.principal_id == user_id, ) ] @@ -860,7 +827,7 @@ class AccessGrantsTable: *( [ and_( - AccessGrant.principal_type == "group", + AccessGrant.principal_type == 'group', AccessGrant.principal_id.in_(group_ids), ) ] @@ -879,9 +846,9 @@ class AccessGrantsTable: .where( AccessGrant.resource_type == resource_type, AccessGrant.resource_id == DocumentModel.id, - AccessGrant.permission == "read", - AccessGrant.principal_type == "user", - AccessGrant.principal_id == "*", + AccessGrant.permission == 'read', + AccessGrant.principal_type == 'user', + AccessGrant.principal_id == '*', ) .correlate(DocumentModel) .exists() diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 25fb873b95..1a1b164c12 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -17,7 +17,7 @@ log = logging.getLogger(__name__) class Auth(Base): - __tablename__ = "auth" + __tablename__ = 'auth' id = Column(String, primary_key=True, unique=True) email = Column(String) @@ -73,9 +73,9 @@ class SignupForm(BaseModel): name: str email: str password: str - profile_image_url: Optional[str] = "/user.png" + profile_image_url: Optional[str] = '/user.png' - @field_validator("profile_image_url") + @field_validator('profile_image_url') @classmethod def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]: if v is not None: @@ -84,7 +84,7 @@ class SignupForm(BaseModel): class AddUserForm(SignupForm): - role: Optional[str] = "pending" + role: Optional[str] = 'pending' class AuthsTable: @@ -93,25 +93,21 @@ class AuthsTable: email: str, password: str, name: str, - profile_image_url: str = "/user.png", - role: str = "pending", + profile_image_url: str = '/user.png', + role: str = 'pending', oauth: Optional[dict] = None, db: Optional[Session] = None, ) -> Optional[UserModel]: with get_db_context(db) as db: - log.info("insert_new_auth") + log.info('insert_new_auth') id = str(uuid.uuid4()) - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) + auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True}) result = Auth(**auth.model_dump()) db.add(result) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth=oauth, db=db - ) + user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db) db.commit() db.refresh(result) @@ -124,7 +120,7 @@ class AuthsTable: def authenticate_user( self, email: str, verify_password: callable, db: Optional[Session] = None ) -> Optional[UserModel]: - log.info(f"authenticate_user: {email}") + log.info(f'authenticate_user: {email}') user = Users.get_user_by_email(email, db=db) if not user: @@ -143,10 +139,8 @@ class AuthsTable: except Exception: return None - def authenticate_user_by_api_key( - self, api_key: str, db: Optional[Session] = None - ) -> Optional[UserModel]: - log.info(f"authenticate_user_by_api_key") + def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]: + log.info(f'authenticate_user_by_api_key') # if no api_key, return None if not api_key: return None @@ -157,10 +151,8 @@ class AuthsTable: except Exception: return False - def authenticate_user_by_email( - self, email: str, db: Optional[Session] = None - ) -> Optional[UserModel]: - log.info(f"authenticate_user_by_email: {email}") + def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]: + log.info(f'authenticate_user_by_email: {email}') try: with get_db_context(db) as db: # Single JOIN query instead of two separate queries @@ -177,28 +169,22 @@ class AuthsTable: except Exception: return None - def update_user_password_by_id( - self, id: str, new_password: str, db: Optional[Session] = None - ) -> bool: + def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - result = ( - db.query(Auth).filter_by(id=id).update({"password": new_password}) - ) + result = db.query(Auth).filter_by(id=id).update({'password': new_password}) db.commit() return True if result == 1 else False except Exception: return False - def update_email_by_id( - self, id: str, email: str, db: Optional[Session] = None - ) -> bool: + def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - result = db.query(Auth).filter_by(id=id).update({"email": email}) + result = db.query(Auth).filter_by(id=id).update({'email': email}) db.commit() if result == 1: - Users.update_user_by_id(id, {"email": email}, db=db) + Users.update_user_by_id(id, {'email': email}, db=db) return True return False except Exception: diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index e212789a44..4d773491d5 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -37,7 +37,7 @@ from sqlalchemy.sql import exists class Channel(Base): - __tablename__ = "channel" + __tablename__ = 'channel' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) @@ -94,7 +94,7 @@ class ChannelModel(BaseModel): class ChannelMember(Base): - __tablename__ = "channel_member" + __tablename__ = 'channel_member' id = Column(Text, primary_key=True, unique=True) channel_id = Column(Text, nullable=False) @@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel): class ChannelFile(Base): - __tablename__ = "channel_file" + __tablename__ = 'channel_file' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text, nullable=False) - channel_id = Column( - Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False - ) - message_id = Column( - Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True - ) - file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False) + message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True) + file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) - __table_args__ = ( - UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"), - ) + __table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),) class ChannelFileModel(BaseModel): @@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel): class ChannelWebhook(Base): - __tablename__ = "channel_webhook" + __tablename__ = 'channel_webhook' id = Column(Text, primary_key=True, unique=True) channel_id = Column(Text, nullable=False) @@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel): class ChannelForm(BaseModel): - name: str = "" + name: str = '' description: Optional[str] = None is_private: Optional[bool] = None data: Optional[dict] = None @@ -255,10 +249,8 @@ class ChannelWebhookForm(BaseModel): class ChannelTable: - def _get_access_grants( - self, channel_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("channel", channel_id, db=db) + def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('channel', channel_id, db=db) def _to_channel_model( self, @@ -266,13 +258,9 @@ class ChannelTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> ChannelModel: - channel_data = ChannelModel.model_validate(channel).model_dump( - exclude={"access_grants"} - ) - channel_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(channel_data["id"], db=db) + channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'}) + channel_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db) ) return ChannelModel.model_validate(channel_data) @@ -313,20 +301,20 @@ class ChannelTable: for uid in user_ids: model = ChannelMemberModel( **{ - "id": str(uuid.uuid4()), - "channel_id": channel_id, - "user_id": uid, - "status": "joined", - "is_active": True, - "is_channel_muted": False, - "is_channel_pinned": False, - "invited_at": now, - "invited_by": invited_by, - "joined_at": now, - "left_at": None, - "last_read_at": now, - "created_at": now, - "updated_at": now, + 'id': str(uuid.uuid4()), + 'channel_id': channel_id, + 'user_id': uid, + 'status': 'joined', + 'is_active': True, + 'is_channel_muted': False, + 'is_channel_pinned': False, + 'invited_at': now, + 'invited_by': invited_by, + 'joined_at': now, + 'left_at': None, + 'last_read_at': now, + 'created_at': now, + 'updated_at': now, } ) memberships.append(ChannelMember(**model.model_dump())) @@ -339,19 +327,19 @@ class ChannelTable: with get_db_context(db) as db: channel = ChannelModel( **{ - **form_data.model_dump(exclude={"access_grants"}), - "type": form_data.type if form_data.type else None, - "name": form_data.name.lower(), - "id": str(uuid.uuid4()), - "user_id": user_id, - "created_at": int(time.time_ns()), - "updated_at": int(time.time_ns()), - "access_grants": [], + **form_data.model_dump(exclude={'access_grants'}), + 'type': form_data.type if form_data.type else None, + 'name': form_data.name.lower(), + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'created_at': int(time.time_ns()), + 'updated_at': int(time.time_ns()), + 'access_grants': [], } ) - new_channel = Channel(**channel.model_dump(exclude={"access_grants"})) + new_channel = Channel(**channel.model_dump(exclude={'access_grants'})) - if form_data.type in ["group", "dm"]: + if form_data.type in ['group', 'dm']: users = self._collect_unique_user_ids( invited_by=user_id, user_ids=form_data.user_ids, @@ -366,18 +354,14 @@ class ChannelTable: db.add_all(memberships) db.add(new_channel) db.commit() - AccessGrants.set_access_grants( - "channel", new_channel.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db) return self._to_channel_model(new_channel, db=db) def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: channels = db.query(Channel).all() channel_ids = [channel.id for channel in channels] - grants_map = AccessGrants.get_grants_by_resources( - "channel", channel_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db) return [ self._to_channel_model( channel, @@ -387,23 +371,19 @@ class ChannelTable: for channel in channels ] - def _has_permission(self, db, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = 'read'): return AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Channel, filter=filter, - resource_type="channel", + resource_type='channel', permission=permission, ) - def get_channels_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[ChannelModel]: + def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: - user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - ] + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)] membership_channels = ( db.query(Channel) @@ -411,7 +391,7 @@ class ChannelTable: .filter( Channel.deleted_at.is_(None), Channel.archived_at.is_(None), - Channel.type.in_(["group", "dm"]), + Channel.type.in_(['group', 'dm']), ChannelMember.user_id == user_id, ChannelMember.is_active.is_(True), ) @@ -423,29 +403,20 @@ class ChannelTable: Channel.archived_at.is_(None), or_( Channel.type.is_(None), # True NULL/None - Channel.type == "", # Empty string - and_(Channel.type != "group", Channel.type != "dm"), + Channel.type == '', # Empty string + and_(Channel.type != 'group', Channel.type != 'dm'), ), ) - query = self._has_permission( - db, query, {"user_id": user_id, "group_ids": user_group_ids} - ) + query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}) standard_channels = query.all() all_channels = membership_channels + standard_channels channel_ids = [c.id for c in all_channels] - grants_map = AccessGrants.get_grants_by_resources( - "channel", channel_ids, db=db - ) - return [ - self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) - for c in all_channels - ] + grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db) + return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels] - def get_dm_channel_by_user_ids( - self, user_ids: list[str], db: Optional[Session] = None - ) -> Optional[ChannelModel]: + def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]: with get_db_context(db) as db: # Ensure uniqueness in case a list with duplicates is passed unique_user_ids = list(set(user_ids)) @@ -471,7 +442,7 @@ class ChannelTable: db.query(Channel) .filter( Channel.id.in_(subquery), - Channel.type == "dm", + Channel.type == 'dm', ) .first() ) @@ -488,32 +459,23 @@ class ChannelTable: ) -> list[ChannelMemberModel]: with get_db_context(db) as db: # 1. Collect all user_ids including groups + inviter - requested_users = self._collect_unique_user_ids( - invited_by, user_ids, group_ids - ) + requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids) existing_users = { row.user_id - for row in db.query(ChannelMember.user_id) - .filter(ChannelMember.channel_id == channel_id) - .all() + for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all() } new_user_ids = requested_users - existing_users if not new_user_ids: return [] # Nothing to add - new_memberships = self._create_membership_models( - channel_id, invited_by, new_user_ids - ) + new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids) db.add_all(new_memberships) db.commit() - return [ - ChannelMemberModel.model_validate(membership) - for membership in new_memberships - ] + return [ChannelMemberModel.model_validate(membership) for membership in new_memberships] def remove_members_from_channel( self, @@ -533,9 +495,7 @@ class ChannelTable: db.commit() return result # number of rows deleted - def is_user_channel_manager( - self, channel_id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: # Check if the user is the creator of the channel # or has a 'manager' role in ChannelMember @@ -548,15 +508,13 @@ class ChannelTable: .filter( ChannelMember.channel_id == channel_id, ChannelMember.user_id == user_id, - ChannelMember.role == "manager", + ChannelMember.role == 'manager', ) .first() ) return membership is not None - def join_channel( - self, channel_id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[ChannelMemberModel]: + def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]: with get_db_context(db) as db: # Check if the membership already exists existing_membership = ( @@ -573,18 +531,18 @@ class ChannelTable: # Create new membership channel_member = ChannelMemberModel( **{ - "id": str(uuid.uuid4()), - "channel_id": channel_id, - "user_id": user_id, - "status": "joined", - "is_active": True, - "is_channel_muted": False, - "is_channel_pinned": False, - "joined_at": int(time.time_ns()), - "left_at": None, - "last_read_at": int(time.time_ns()), - "created_at": int(time.time_ns()), - "updated_at": int(time.time_ns()), + 'id': str(uuid.uuid4()), + 'channel_id': channel_id, + 'user_id': user_id, + 'status': 'joined', + 'is_active': True, + 'is_channel_muted': False, + 'is_channel_pinned': False, + 'joined_at': int(time.time_ns()), + 'left_at': None, + 'last_read_at': int(time.time_ns()), + 'created_at': int(time.time_ns()), + 'updated_at': int(time.time_ns()), } ) new_membership = ChannelMember(**channel_member.model_dump()) @@ -593,9 +551,7 @@ class ChannelTable: db.commit() return channel_member - def leave_channel( - self, channel_id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -608,7 +564,7 @@ class ChannelTable: if not membership: return False - membership.status = "left" + membership.status = 'left' membership.is_active = False membership.left_at = int(time.time_ns()) membership.updated_at = int(time.time_ns()) @@ -630,19 +586,10 @@ class ChannelTable: ) return ChannelMemberModel.model_validate(membership) if membership else None - def get_members_by_channel_id( - self, channel_id: str, db: Optional[Session] = None - ) -> list[ChannelMemberModel]: + def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]: with get_db_context(db) as db: - memberships = ( - db.query(ChannelMember) - .filter(ChannelMember.channel_id == channel_id) - .all() - ) - return [ - ChannelMemberModel.model_validate(membership) - for membership in memberships - ] + memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all() + return [ChannelMemberModel.model_validate(membership) for membership in memberships] def pin_channel( self, @@ -669,9 +616,7 @@ class ChannelTable: db.commit() return True - def update_member_last_read_at( - self, channel_id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -715,9 +660,7 @@ class ChannelTable: db.commit() return True - def is_user_channel_member( - self, channel_id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -729,9 +672,7 @@ class ChannelTable: ) return membership is not None - def get_channel_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChannelModel]: + def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]: try: with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() @@ -739,18 +680,12 @@ class ChannelTable: except Exception: return None - def get_channels_by_file_id( - self, file_id: str, db: Optional[Session] = None - ) -> list[ChannelModel]: + def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: - channel_files = ( - db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() - ) + channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() channel_ids = [cf.channel_id for cf in channel_files] channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all() - grants_map = AccessGrants.get_grants_by_resources( - "channel", channel_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db) return [ self._to_channel_model( channel, @@ -765,9 +700,7 @@ class ChannelTable: ) -> list[ChannelModel]: with get_db_context(db) as db: # 1. Determine which channels have this file - channel_file_rows = ( - db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() - ) + channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() channel_ids = [row.channel_id for row in channel_file_rows] if not channel_ids: @@ -787,15 +720,13 @@ class ChannelTable: return [] # Preload user's group membership - user_group_ids = [ - g.id for g in Groups.get_groups_by_member_id(user_id, db=db) - ] + user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)] allowed_channels = [] for channel in channels: # --- Case A: group or dm => user must be an active member --- - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: membership = ( db.query(ChannelMember) .filter( @@ -815,8 +746,8 @@ class ChannelTable: query = self._has_permission( db, query, - {"user_id": user_id, "group_ids": user_group_ids}, - permission="read", + {'user_id': user_id, 'group_ids': user_group_ids}, + permission='read', ) allowed = query.first() @@ -844,7 +775,7 @@ class ChannelTable: return None # If the channel is a group or dm, read access requires membership (active) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: membership = ( db.query(ChannelMember) .filter( @@ -863,24 +794,18 @@ class ChannelTable: query = db.query(Channel).filter(Channel.id == id) # Determine user groups - user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - ] + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)] # Apply ACL rules query = self._has_permission( db, query, - {"user_id": user_id, "group_ids": user_group_ids}, - permission="read", + {'user_id': user_id, 'group_ids': user_group_ids}, + permission='read', ) channel_allowed = query.first() - return ( - self._to_channel_model(channel_allowed, db=db) - if channel_allowed - else None - ) + return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None def update_channel_by_id( self, id: str, form_data: ChannelForm, db: Optional[Session] = None @@ -898,9 +823,7 @@ class ChannelTable: channel.meta = form_data.meta if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "channel", id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db) channel.updated_at = int(time.time_ns()) db.commit() @@ -912,12 +835,12 @@ class ChannelTable: with get_db_context(db) as db: channel_file = ChannelFileModel( **{ - "id": str(uuid.uuid4()), - "channel_id": channel_id, - "file_id": file_id, - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'channel_id': channel_id, + 'file_id': file_id, + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -942,11 +865,7 @@ class ChannelTable: ) -> bool: try: with get_db_context(db) as db: - channel_file = ( - db.query(ChannelFile) - .filter_by(channel_id=channel_id, file_id=file_id) - .first() - ) + channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first() if not channel_file: return False @@ -958,14 +877,10 @@ class ChannelTable: except Exception: return False - def remove_file_from_channel_by_id( - self, channel_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - db.query(ChannelFile).filter_by( - channel_id=channel_id, file_id=file_id - ).delete() + db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete() db.commit() return True except Exception: @@ -973,7 +888,7 @@ class ChannelTable: def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: - AccessGrants.revoke_all_access("channel", id, db=db) + AccessGrants.revoke_all_access('channel', id, db=db) db.query(Channel).filter(Channel.id == id).delete() db.commit() return True @@ -1005,24 +920,14 @@ class ChannelTable: db.commit() return webhook - def get_webhooks_by_channel_id( - self, channel_id: str, db: Optional[Session] = None - ) -> list[ChannelWebhookModel]: + def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]: with get_db_context(db) as db: - webhooks = ( - db.query(ChannelWebhook) - .filter(ChannelWebhook.channel_id == channel_id) - .all() - ) + webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all() return [ChannelWebhookModel.model_validate(w) for w in webhooks] - def get_webhook_by_id( - self, webhook_id: str, db: Optional[Session] = None - ) -> Optional[ChannelWebhookModel]: + def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]: with get_db_context(db) as db: - webhook = ( - db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() - ) + webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() return ChannelWebhookModel.model_validate(webhook) if webhook else None def get_webhook_by_id_and_token( @@ -1046,9 +951,7 @@ class ChannelTable: db: Optional[Session] = None, ) -> Optional[ChannelWebhookModel]: with get_db_context(db) as db: - webhook = ( - db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() - ) + webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() if not webhook: return None webhook.name = form_data.name @@ -1057,28 +960,18 @@ class ChannelTable: db.commit() return ChannelWebhookModel.model_validate(webhook) - def update_webhook_last_used_at( - self, webhook_id: str, db: Optional[Session] = None - ) -> bool: + def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: - webhook = ( - db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() - ) + webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() if not webhook: return False webhook.last_used_at = int(time.time_ns()) db.commit() return True - def delete_webhook_by_id( - self, webhook_id: str, db: Optional[Session] = None - ) -> bool: + def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: - result = ( - db.query(ChannelWebhook) - .filter(ChannelWebhook.id == webhook_id) - .delete() - ) + result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete() db.commit() return result > 0 diff --git a/backend/open_webui/models/chat_messages.py b/backend/open_webui/models/chat_messages.py index 00609ce7f0..97490c1602 100644 --- a/backend/open_webui/models/chat_messages.py +++ b/backend/open_webui/models/chat_messages.py @@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float: class ChatMessage(Base): - __tablename__ = "chat_message" + __tablename__ = 'chat_message' # Identity id = Column(Text, primary_key=True) - chat_id = Column( - Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True - ) + chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True) user_id = Column(Text, index=True) # Structure @@ -85,9 +83,9 @@ class ChatMessage(Base): updated_at = Column(BigInteger) __table_args__ = ( - Index("chat_message_chat_parent_idx", "chat_id", "parent_id"), - Index("chat_message_model_created_idx", "model_id", "created_at"), - Index("chat_message_user_created_idx", "user_id", "created_at"), + Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'), + Index('chat_message_model_created_idx', 'model_id', 'created_at'), + Index('chat_message_user_created_idx', 'user_id', 'created_at'), ) @@ -135,43 +133,41 @@ class ChatMessageTable: """Insert or update a chat message.""" with get_db_context(db) as db: now = int(time.time()) - timestamp = data.get("timestamp", now) + timestamp = data.get('timestamp', now) # Use composite ID: {chat_id}-{message_id} - composite_id = f"{chat_id}-{message_id}" + composite_id = f'{chat_id}-{message_id}' existing = db.get(ChatMessage, composite_id) if existing: # Update existing - if "role" in data: - existing.role = data["role"] - if "parent_id" in data: - existing.parent_id = data.get("parent_id") or data.get("parentId") - if "content" in data: - existing.content = data.get("content") - if "output" in data: - existing.output = data.get("output") - if "model_id" in data or "model" in data: - existing.model_id = data.get("model_id") or data.get("model") - if "files" in data: - existing.files = data.get("files") - if "sources" in data: - existing.sources = data.get("sources") - if "embeds" in data: - existing.embeds = data.get("embeds") - if "done" in data: - existing.done = data.get("done", True) - if "status_history" in data or "statusHistory" in data: - existing.status_history = data.get("status_history") or data.get( - "statusHistory" - ) - if "error" in data: - existing.error = data.get("error") + if 'role' in data: + existing.role = data['role'] + if 'parent_id' in data: + existing.parent_id = data.get('parent_id') or data.get('parentId') + if 'content' in data: + existing.content = data.get('content') + if 'output' in data: + existing.output = data.get('output') + if 'model_id' in data or 'model' in data: + existing.model_id = data.get('model_id') or data.get('model') + if 'files' in data: + existing.files = data.get('files') + if 'sources' in data: + existing.sources = data.get('sources') + if 'embeds' in data: + existing.embeds = data.get('embeds') + if 'done' in data: + existing.done = data.get('done', True) + if 'status_history' in data or 'statusHistory' in data: + existing.status_history = data.get('status_history') or data.get('statusHistory') + if 'error' in data: + existing.error = data.get('error') # Extract usage - check direct field first, then info.usage - usage = data.get("usage") + usage = data.get('usage') if not usage: - info = data.get("info", {}) - usage = info.get("usage") if info else None + info = data.get('info', {}) + usage = info.get('usage') if info else None if usage: existing.usage = usage existing.updated_at = now @@ -181,26 +177,25 @@ class ChatMessageTable: else: # Insert new # Extract usage - check direct field first, then info.usage - usage = data.get("usage") + usage = data.get('usage') if not usage: - info = data.get("info", {}) - usage = info.get("usage") if info else None + info = data.get('info', {}) + usage = info.get('usage') if info else None message = ChatMessage( id=composite_id, chat_id=chat_id, user_id=user_id, - role=data.get("role", "user"), - parent_id=data.get("parent_id") or data.get("parentId"), - content=data.get("content"), - output=data.get("output"), - model_id=data.get("model_id") or data.get("model"), - files=data.get("files"), - sources=data.get("sources"), - embeds=data.get("embeds"), - done=data.get("done", True), - status_history=data.get("status_history") - or data.get("statusHistory"), - error=data.get("error"), + role=data.get('role', 'user'), + parent_id=data.get('parent_id') or data.get('parentId'), + content=data.get('content'), + output=data.get('output'), + model_id=data.get('model_id') or data.get('model'), + files=data.get('files'), + sources=data.get('sources'), + embeds=data.get('embeds'), + done=data.get('done', True), + status_history=data.get('status_history') or data.get('statusHistory'), + error=data.get('error'), usage=usage, created_at=timestamp, updated_at=now, @@ -210,23 +205,14 @@ class ChatMessageTable: db.refresh(message) return ChatMessageModel.model_validate(message) - def get_message_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatMessageModel]: + def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]: with get_db_context(db) as db: message = db.get(ChatMessage, id) return ChatMessageModel.model_validate(message) if message else None - def get_messages_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> list[ChatMessageModel]: + def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]: with get_db_context(db) as db: - messages = ( - db.query(ChatMessage) - .filter_by(chat_id=chat_id) - .order_by(ChatMessage.created_at.asc()) - .all() - ) + messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all() return [ChatMessageModel.model_validate(message) for message in messages] def get_messages_by_user_id( @@ -262,12 +248,7 @@ class ChatMessageTable: query = query.filter(ChatMessage.created_at >= start_date) if end_date: query = query.filter(ChatMessage.created_at <= end_date) - messages = ( - query.order_by(ChatMessage.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) + messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all() return [ChatMessageModel.model_validate(message) for message in messages] def get_chat_ids_by_model_id( @@ -284,7 +265,7 @@ class ChatMessageTable: with get_db_context(db) as db: query = db.query( ChatMessage.chat_id, - func.max(ChatMessage.created_at).label("last_message_at"), + func.max(ChatMessage.created_at).label('last_message_at'), ).filter(ChatMessage.model_id == model_id) if start_date: query = query.filter(ChatMessage.created_at >= start_date) @@ -303,9 +284,7 @@ class ChatMessageTable: ) return [chat_id for chat_id, _ in chat_ids] - def delete_messages_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> bool: + def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: db.query(ChatMessage).filter_by(chat_id=chat_id).delete() db.commit() @@ -323,12 +302,10 @@ class ChatMessageTable: from sqlalchemy import func from open_webui.models.groups import GroupMember - query = db.query( - ChatMessage.model_id, func.count(ChatMessage.id).label("count") - ).filter( - ChatMessage.role == "assistant", + query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter( + ChatMessage.role == 'assistant', ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like("shared-%"), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -336,11 +313,7 @@ class ChatMessageTable: if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.group_by(ChatMessage.model_id).all() @@ -360,36 +333,32 @@ class ChatMessageTable: dialect = db.bind.dialect.name - if dialect == "sqlite": - input_tokens = cast( - func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer - ) - output_tokens = cast( - func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer - ) - elif dialect == "postgresql": + if dialect == 'sqlite': + input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer) + output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer) + elif dialect == 'postgresql': # Use json_extract_path_text for PostgreSQL JSON columns input_tokens = cast( - func.json_extract_path_text(ChatMessage.usage, "input_tokens"), + func.json_extract_path_text(ChatMessage.usage, 'input_tokens'), Integer, ) output_tokens = cast( - func.json_extract_path_text(ChatMessage.usage, "output_tokens"), + func.json_extract_path_text(ChatMessage.usage, 'output_tokens'), Integer, ) else: - raise NotImplementedError(f"Unsupported dialect: {dialect}") + raise NotImplementedError(f'Unsupported dialect: {dialect}') query = db.query( ChatMessage.model_id, - func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), - func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), - func.count(ChatMessage.id).label("message_count"), + func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), + func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), + func.count(ChatMessage.id).label('message_count'), ).filter( - ChatMessage.role == "assistant", + ChatMessage.role == 'assistant', ChatMessage.model_id.isnot(None), ChatMessage.usage.isnot(None), - ~ChatMessage.user_id.like("shared-%"), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -397,21 +366,17 @@ class ChatMessageTable: if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.group_by(ChatMessage.model_id).all() return { row.model_id: { - "input_tokens": row.input_tokens, - "output_tokens": row.output_tokens, - "total_tokens": row.input_tokens + row.output_tokens, - "message_count": row.message_count, + 'input_tokens': row.input_tokens, + 'output_tokens': row.output_tokens, + 'total_tokens': row.input_tokens + row.output_tokens, + 'message_count': row.message_count, } for row in results } @@ -430,36 +395,32 @@ class ChatMessageTable: dialect = db.bind.dialect.name - if dialect == "sqlite": - input_tokens = cast( - func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer - ) - output_tokens = cast( - func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer - ) - elif dialect == "postgresql": + if dialect == 'sqlite': + input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer) + output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer) + elif dialect == 'postgresql': # Use json_extract_path_text for PostgreSQL JSON columns input_tokens = cast( - func.json_extract_path_text(ChatMessage.usage, "input_tokens"), + func.json_extract_path_text(ChatMessage.usage, 'input_tokens'), Integer, ) output_tokens = cast( - func.json_extract_path_text(ChatMessage.usage, "output_tokens"), + func.json_extract_path_text(ChatMessage.usage, 'output_tokens'), Integer, ) else: - raise NotImplementedError(f"Unsupported dialect: {dialect}") + raise NotImplementedError(f'Unsupported dialect: {dialect}') query = db.query( ChatMessage.user_id, - func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), - func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), - func.count(ChatMessage.id).label("message_count"), + func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), + func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), + func.count(ChatMessage.id).label('message_count'), ).filter( - ChatMessage.role == "assistant", + ChatMessage.role == 'assistant', ChatMessage.user_id.isnot(None), ChatMessage.usage.isnot(None), - ~ChatMessage.user_id.like("shared-%"), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -467,21 +428,17 @@ class ChatMessageTable: if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.group_by(ChatMessage.user_id).all() return { row.user_id: { - "input_tokens": row.input_tokens, - "output_tokens": row.output_tokens, - "total_tokens": row.input_tokens + row.output_tokens, - "message_count": row.message_count, + 'input_tokens': row.input_tokens, + 'output_tokens': row.output_tokens, + 'total_tokens': row.input_tokens + row.output_tokens, + 'message_count': row.message_count, } for row in results } @@ -497,20 +454,16 @@ class ChatMessageTable: from sqlalchemy import func from open_webui.models.groups import GroupMember - query = db.query( - ChatMessage.user_id, func.count(ChatMessage.id).label("count") - ).filter(~ChatMessage.user_id.like("shared-%")) + query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter( + ~ChatMessage.user_id.like('shared-%') + ) if start_date: query = query.filter(ChatMessage.created_at >= start_date) if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.group_by(ChatMessage.user_id).all() @@ -527,20 +480,16 @@ class ChatMessageTable: from sqlalchemy import func from open_webui.models.groups import GroupMember - query = db.query( - ChatMessage.chat_id, func.count(ChatMessage.id).label("count") - ).filter(~ChatMessage.user_id.like("shared-%")) + query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter( + ~ChatMessage.user_id.like('shared-%') + ) if start_date: query = query.filter(ChatMessage.created_at >= start_date) if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.group_by(ChatMessage.chat_id).all() @@ -559,9 +508,9 @@ class ChatMessageTable: from open_webui.models.groups import GroupMember query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( - ChatMessage.role == "assistant", + ChatMessage.role == 'assistant', ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like("shared-%"), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -569,11 +518,7 @@ class ChatMessageTable: if end_date: query = query.filter(ChatMessage.created_at <= end_date) if group_id: - group_users = ( - db.query(GroupMember.user_id) - .filter(GroupMember.group_id == group_id) - .subquery() - ) + group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery() query = query.filter(ChatMessage.user_id.in_(group_users)) results = query.all() @@ -581,21 +526,17 @@ class ChatMessageTable: # Group by date -> model -> count daily_counts: dict[str, dict[str, int]] = {} for timestamp, model_id in results: - date_str = datetime.fromtimestamp( - _normalize_timestamp(timestamp) - ).strftime("%Y-%m-%d") + date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d') if date_str not in daily_counts: daily_counts[date_str] = {} - daily_counts[date_str][model_id] = ( - daily_counts[date_str].get(model_id, 0) + 1 - ) + daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1 # Fill in missing days if start_date and end_date: current = datetime.fromtimestamp(_normalize_timestamp(start_date)) end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date)) while current <= end_dt: - date_str = current.strftime("%Y-%m-%d") + date_str = current.strftime('%Y-%m-%d') if date_str not in daily_counts: daily_counts[date_str] = {} current += timedelta(days=1) @@ -613,9 +554,9 @@ class ChatMessageTable: from datetime import datetime, timedelta query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( - ChatMessage.role == "assistant", + ChatMessage.role == 'assistant', ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like("shared-%"), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -628,23 +569,19 @@ class ChatMessageTable: # Group by hour -> model -> count hourly_counts: dict[str, dict[str, int]] = {} for timestamp, model_id in results: - hour_str = datetime.fromtimestamp( - _normalize_timestamp(timestamp) - ).strftime("%Y-%m-%d %H:00") + hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00') if hour_str not in hourly_counts: hourly_counts[hour_str] = {} - hourly_counts[hour_str][model_id] = ( - hourly_counts[hour_str].get(model_id, 0) + 1 - ) + hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1 # Fill in missing hours if start_date and end_date: - current = datetime.fromtimestamp( - _normalize_timestamp(start_date) - ).replace(minute=0, second=0, microsecond=0) + current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace( + minute=0, second=0, microsecond=0 + ) end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date)) while current <= end_dt: - hour_str = current.strftime("%Y-%m-%d %H:00") + hour_str = current.strftime('%Y-%m-%d %H:00') if hour_str not in hourly_counts: hourly_counts[hour_str] = {} current += timedelta(hours=1) diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 78aa4d3ded..9fe923f004 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -35,7 +35,7 @@ log = logging.getLogger(__name__) class Chat(Base): - __tablename__ = "chat" + __tablename__ = 'chat' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -49,21 +49,21 @@ class Chat(Base): archived = Column(Boolean, default=False) pinned = Column(Boolean, default=False, nullable=True) - meta = Column(JSON, server_default="{}") + meta = Column(JSON, server_default='{}') folder_id = Column(Text, nullable=True) __table_args__ = ( # Performance indexes for common queries # WHERE folder_id = ... - Index("folder_id_idx", "folder_id"), + Index('folder_id_idx', 'folder_id'), # WHERE user_id = ... AND pinned = ... - Index("user_id_pinned_idx", "user_id", "pinned"), + Index('user_id_pinned_idx', 'user_id', 'pinned'), # WHERE user_id = ... AND archived = ... - Index("user_id_archived_idx", "user_id", "archived"), + Index('user_id_archived_idx', 'user_id', 'archived'), # WHERE user_id = ... ORDER BY updated_at DESC - Index("updated_at_user_id_idx", "updated_at", "user_id"), + Index('updated_at_user_id_idx', 'updated_at', 'user_id'), # WHERE folder_id = ... AND user_id = ... - Index("folder_id_user_id_idx", "folder_id", "user_id"), + Index('folder_id_user_id_idx', 'folder_id', 'user_id'), ) @@ -87,21 +87,19 @@ class ChatModel(BaseModel): class ChatFile(Base): - __tablename__ = "chat_file" + __tablename__ = 'chat_file' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text, nullable=False) - chat_id = Column(Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False) + chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False) message_id = Column(Text, nullable=True) - file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) - __table_args__ = ( - UniqueConstraint("chat_id", "file_id", name="uq_chat_file_chat_file"), - ) + __table_args__ = (UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'),) class ChatFileModel(BaseModel): @@ -191,19 +189,11 @@ class ChatUsageStatsResponse(BaseModel): history_models: dict = {} # models used in the chat history with their usage counts history_message_count: int # number of messages in the chat history history_user_message_count: int # number of user messages in the chat history - history_assistant_message_count: ( - int # number of assistant messages in the chat history - ) + history_assistant_message_count: int # number of assistant messages in the chat history - average_response_time: ( - float # average response time of assistant messages in seconds - ) - average_user_message_content_length: ( - float # average length of user message contents - ) - average_assistant_message_content_length: ( - float # average length of assistant message contents - ) + average_response_time: float # average response time of assistant messages in seconds + average_user_message_content_length: float # average length of user message contents + average_assistant_message_content_length: float # average length of assistant message contents tags: list[str] = [] # tags associated with the chat @@ -211,13 +201,13 @@ class ChatUsageStatsResponse(BaseModel): updated_at: int created_at: int - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class ChatUsageStatsListResponse(BaseModel): items: list[ChatUsageStatsResponse] total: int - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class MessageStats(BaseModel): @@ -290,24 +280,20 @@ class ChatTable: return changed - def insert_new_chat( - self, user_id: str, form_data: ChatForm, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[Session] = None) -> Optional[ChatModel]: with get_db_context(db) as db: id = str(uuid.uuid4()) chat = ChatModel( **{ - "id": id, - "user_id": user_id, - "title": self._clean_null_bytes( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" + 'id': id, + 'user_id': user_id, + 'title': self._clean_null_bytes( + form_data.chat['title'] if 'title' in form_data.chat else 'New Chat' ), - "chat": self._clean_null_bytes(form_data.chat), - "folder_id": form_data.folder_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'chat': self._clean_null_bytes(form_data.chat), + 'folder_id': form_data.folder_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -318,10 +304,10 @@ class ChatTable: # Dual-write initial messages to chat_message table try: - history = form_data.chat.get("history", {}) - messages = history.get("messages", {}) + history = form_data.chat.get('history', {}) + messages = history.get('messages', {}) for message_id, message in messages.items(): - if isinstance(message, dict) and message.get("role"): + if isinstance(message, dict) and message.get('role'): ChatMessages.upsert_message( message_id=message_id, chat_id=id, @@ -329,33 +315,23 @@ class ChatTable: data=message, ) except Exception as e: - log.warning( - f"Failed to write initial messages to chat_message table: {e}" - ) + log.warning(f'Failed to write initial messages to chat_message table: {e}') return ChatModel.model_validate(chat_item) if chat_item else None - def _chat_import_form_to_chat_model( - self, user_id: str, form_data: ChatImportForm - ) -> ChatModel: + def _chat_import_form_to_chat_model(self, user_id: str, form_data: ChatImportForm) -> ChatModel: id = str(uuid.uuid4()) chat = ChatModel( **{ - "id": id, - "user_id": user_id, - "title": self._clean_null_bytes( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": self._clean_null_bytes(form_data.chat), - "meta": form_data.meta, - "pinned": form_data.pinned, - "folder_id": form_data.folder_id, - "created_at": ( - form_data.created_at if form_data.created_at else int(time.time()) - ), - "updated_at": ( - form_data.updated_at if form_data.updated_at else int(time.time()) - ), + 'id': id, + 'user_id': user_id, + 'title': self._clean_null_bytes(form_data.chat['title'] if 'title' in form_data.chat else 'New Chat'), + 'chat': self._clean_null_bytes(form_data.chat), + 'meta': form_data.meta, + 'pinned': form_data.pinned, + 'folder_id': form_data.folder_id, + 'created_at': (form_data.created_at if form_data.created_at else int(time.time())), + 'updated_at': (form_data.updated_at if form_data.updated_at else int(time.time())), } ) return chat @@ -379,10 +355,10 @@ class ChatTable: # Dual-write messages to chat_message table try: for form_data, chat_obj in zip(chat_import_forms, chats): - history = form_data.chat.get("history", {}) - messages = history.get("messages", {}) + history = form_data.chat.get('history', {}) + messages = history.get('messages', {}) for message_id, message in messages.items(): - if isinstance(message, dict) and message.get("role"): + if isinstance(message, dict) and message.get('role'): ChatMessages.upsert_message( message_id=message_id, chat_id=chat_obj.id, @@ -390,24 +366,16 @@ class ChatTable: data=message, ) except Exception as e: - log.warning( - f"Failed to write imported messages to chat_message table: {e}" - ) + log.warning(f'Failed to write imported messages to chat_message table: {e}') return [ChatModel.model_validate(chat) for chat in chats] - def update_chat_by_id( - self, id: str, chat: dict, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def update_chat_by_id(self, id: str, chat: dict, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat_item = db.get(Chat, id) chat_item.chat = self._clean_null_bytes(chat) - chat_item.title = ( - self._clean_null_bytes(chat["title"]) - if "title" in chat - else "New Chat" - ) + chat_item.title = self._clean_null_bytes(chat['title']) if 'title' in chat else 'New Chat' chat_item.updated_at = int(time.time()) @@ -424,24 +392,22 @@ class ChatTable: return None chat = chat.chat - chat["title"] = title + chat['title'] = title return self.update_chat_by_id(id, chat) - def update_chat_tags_by_id( - self, id: str, tags: list[str], user - ) -> Optional[ChatModel]: + def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> Optional[ChatModel]: with get_db_context() as db: chat = db.get(Chat, id) if chat is None: return None - old_tags = chat.meta.get("tags", []) - new_tags = [t for t in tags if t.replace(" ", "_").lower() != "none"] - new_tag_ids = [t.replace(" ", "_").lower() for t in new_tags] + old_tags = chat.meta.get('tags', []) + new_tags = [t for t in tags if t.replace(' ', '_').lower() != 'none'] + new_tag_ids = [t.replace(' ', '_').lower() for t in new_tags] # Single meta update - chat.meta = {**chat.meta, "tags": new_tag_ids} + chat.meta = {**chat.meta, 'tags': new_tag_ids} db.commit() db.refresh(chat) @@ -460,23 +426,21 @@ class ChatTable: result = db.query(Chat.title).filter_by(id=id).first() if result is None: return None - return result[0] or "New Chat" + return result[0] or 'New Chat' def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]: chat = self.get_chat_by_id(id) if chat is None: return None - return chat.chat.get("history", {}).get("messages", {}) or {} + return chat.chat.get('history', {}).get('messages', {}) or {} - def get_message_by_id_and_message_id( - self, id: str, message_id: str - ) -> Optional[dict]: + def get_message_by_id_and_message_id(self, id: str, message_id: str) -> Optional[dict]: chat = self.get_chat_by_id(id) if chat is None: return None - return chat.chat.get("history", {}).get("messages", {}).get(message_id, {}) + return chat.chat.get('history', {}).get('messages', {}).get(message_id, {}) def upsert_message_to_chat_by_id_and_message_id( self, id: str, message_id: str, message: dict @@ -486,24 +450,24 @@ class ChatTable: return None # Sanitize message content for null characters before upserting - if isinstance(message.get("content"), str): - message["content"] = sanitize_text_for_db(message["content"]) + if isinstance(message.get('content'), str): + message['content'] = sanitize_text_for_db(message['content']) user_id = chat.user_id chat = chat.chat - history = chat.get("history", {}) + history = chat.get('history', {}) - if message_id in history.get("messages", {}): - history["messages"][message_id] = { - **history["messages"][message_id], + if message_id in history.get('messages', {}): + history['messages'][message_id] = { + **history['messages'][message_id], **message, } else: - history["messages"][message_id] = message + history['messages'][message_id] = message - history["currentId"] = message_id + history['currentId'] = message_id - chat["history"] = history + chat['history'] = history # Dual-write to chat_message table try: @@ -511,10 +475,10 @@ class ChatTable: message_id=message_id, chat_id=id, user_id=user_id, - data=history["messages"][message_id], + data=history['messages'][message_id], ) except Exception as e: - log.warning(f"Failed to write to chat_message table: {e}") + log.warning(f'Failed to write to chat_message table: {e}') return self.update_chat_by_id(id, chat) @@ -526,41 +490,37 @@ class ChatTable: return None chat = chat.chat - history = chat.get("history", {}) + history = chat.get('history', {}) - if message_id in history.get("messages", {}): - status_history = history["messages"][message_id].get("statusHistory", []) + if message_id in history.get('messages', {}): + status_history = history['messages'][message_id].get('statusHistory', []) status_history.append(status) - history["messages"][message_id]["statusHistory"] = status_history + history['messages'][message_id]['statusHistory'] = status_history - chat["history"] = history + chat['history'] = history return self.update_chat_by_id(id, chat) - def add_message_files_by_id_and_message_id( - self, id: str, message_id: str, files: list[dict] - ) -> list[dict]: + def add_message_files_by_id_and_message_id(self, id: str, message_id: str, files: list[dict]) -> list[dict]: with get_db_context() as db: chat = self.get_chat_by_id(id, db=db) if chat is None: return None chat = chat.chat - history = chat.get("history", {}) + history = chat.get('history', {}) message_files = [] - if message_id in history.get("messages", {}): - message_files = history["messages"][message_id].get("files", []) + if message_id in history.get('messages', {}): + message_files = history['messages'][message_id].get('files', []) message_files = message_files + files - history["messages"][message_id]["files"] = message_files + history['messages'][message_id]['files'] = message_files - chat["history"] = history + chat['history'] = history self.update_chat_by_id(id, chat, db=db) return message_files - def insert_shared_chat_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: with get_db_context(db) as db: # Get the existing chat to share chat = db.get(Chat, chat_id) @@ -569,19 +529,19 @@ class ChatTable: return None # Check if the chat is already shared if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared", db=db) + return self.get_chat_by_id_and_user_id(chat.share_id, 'shared', db=db) # Create a new chat with the same data, but with a new ID shared_chat = ChatModel( **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "meta": chat.meta, - "pinned": chat.pinned, - "folder_id": chat.folder_id, - "created_at": chat.created_at, - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'user_id': f'shared-{chat_id}', + 'title': chat.title, + 'chat': chat.chat, + 'meta': chat.meta, + 'pinned': chat.pinned, + 'folder_id': chat.folder_id, + 'created_at': chat.created_at, + 'updated_at': int(time.time()), } ) shared_result = Chat(**shared_chat.model_dump()) @@ -590,23 +550,15 @@ class ChatTable: db.refresh(shared_result) # Update the original chat with the share_id - result = ( - db.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + result = db.query(Chat).filter_by(id=chat_id).update({'share_id': shared_chat.id}) db.commit() return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat = db.get(Chat, chat_id) - shared_chat = ( - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() - ) + shared_chat = db.query(Chat).filter_by(user_id=f'shared-{chat_id}').first() if shared_chat is None: return self.insert_shared_chat_by_chat_id(chat_id, db=db) @@ -624,33 +576,25 @@ class ChatTable: except Exception: return None - def delete_shared_chat_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> bool: + def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: # Use subquery to delete chat_messages for shared chats - shared_chat_id_subquery = ( - db.query(Chat.id) - .filter_by(user_id=f"shared-{chat_id}") - .scalar_subquery() + shared_chat_id_subquery = db.query(Chat.id).filter_by(user_id=f'shared-{chat_id}').scalar_subquery() + db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_chat_id_subquery)).delete( + synchronize_session=False ) - db.query(ChatMessage).filter( - ChatMessage.chat_id.in_(shared_chat_id_subquery) - ).delete(synchronize_session=False) - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + db.query(Chat).filter_by(user_id=f'shared-{chat_id}').delete() db.commit() return True except Exception: return False - def unarchive_all_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) + db.query(Chat).filter_by(user_id=user_id).update({'archived': False}) db.commit() return True except Exception: @@ -669,9 +613,7 @@ class ChatTable: except Exception: return None - def toggle_chat_pinned_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def toggle_chat_pinned_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat = db.get(Chat, id) @@ -683,9 +625,7 @@ class ChatTable: except Exception: return None - def toggle_chat_archive_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def toggle_chat_archive_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat = db.get(Chat, id) @@ -698,12 +638,10 @@ class ChatTable: except Exception: return None - def archive_all_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def archive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + db.query(Chat).filter_by(user_id=user_id).update({'archived': True}) db.commit() return True except Exception: @@ -717,34 +655,31 @@ class ChatTable: limit: int = 50, db: Optional[Session] = None, ) -> list[ChatTitleIdResponse]: - with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id, archived=True) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: - query = query.filter(Chat.title.ilike(f"%{query_key}%")) + query = query.filter(Chat.title.ilike(f'%{query_key}%')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction: if not getattr(Chat, order_by, None): - raise ValueError("Invalid order_by field") + raise ValueError('Invalid order_by field') - if direction.lower() == "asc": + if direction.lower() == 'asc': query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) - elif direction.lower() == "desc": + elif direction.lower() == 'desc': query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: - raise ValueError("Invalid direction for ordering") + raise ValueError('Invalid direction for ordering') else: query = query.order_by(Chat.updated_at.desc(), Chat.id) - query = query.with_entities( - Chat.id, Chat.title, Chat.updated_at, Chat.created_at - ) + query = query.with_entities(Chat.id, Chat.title, Chat.updated_at, Chat.created_at) if skip: query = query.offset(skip) @@ -755,10 +690,10 @@ class ChatTable: return [ ChatTitleIdResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "updated_at": chat[2], - "created_at": chat[3], + 'id': chat[0], + 'title': chat[1], + 'updated_at': chat[2], + 'created_at': chat[3], } ) for chat in all_chats @@ -772,32 +707,27 @@ class ChatTable: limit: int = 50, db: Optional[Session] = None, ) -> list[SharedChatResponse]: - with get_db_context(db) as db: - query = ( - db.query(Chat) - .filter_by(user_id=user_id) - .filter(Chat.share_id.isnot(None)) - ) + query = db.query(Chat).filter_by(user_id=user_id).filter(Chat.share_id.isnot(None)) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: - query = query.filter(Chat.title.ilike(f"%{query_key}%")) + query = query.filter(Chat.title.ilike(f'%{query_key}%')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction: if not getattr(Chat, order_by, None): - raise ValueError("Invalid order_by field") + raise ValueError('Invalid order_by field') - if direction.lower() == "asc": + if direction.lower() == 'asc': query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) - elif direction.lower() == "desc": + elif direction.lower() == 'desc': query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: - raise ValueError("Invalid direction for ordering") + raise ValueError('Invalid direction for ordering') else: query = query.order_by(Chat.updated_at.desc(), Chat.id) @@ -820,11 +750,11 @@ class ChatTable: return [ SharedChatResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "share_id": chat[2], - "updated_at": chat[3], - "created_at": chat[4], + 'id': chat[0], + 'title': chat[1], + 'share_id': chat[2], + 'updated_at': chat[3], + 'created_at': chat[4], } ) for chat in all_chats @@ -845,20 +775,20 @@ class ChatTable: query = query.filter_by(archived=False) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: - query = query.filter(Chat.title.ilike(f"%{query_key}%")) + query = query.filter(Chat.title.ilike(f'%{query_key}%')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction and getattr(Chat, order_by): - if direction.lower() == "asc": + if direction.lower() == 'asc': query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) - elif direction.lower() == "desc": + elif direction.lower() == 'desc': query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: - raise ValueError("Invalid direction for ordering") + raise ValueError('Invalid direction for ordering') else: query = query.order_by(Chat.updated_at.desc(), Chat.id) @@ -907,10 +837,10 @@ class ChatTable: return [ ChatTitleIdResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "updated_at": chat[2], - "created_at": chat[3], + 'id': chat[0], + 'title': chat[1], + 'updated_at': chat[2], + 'created_at': chat[3], } ) for chat in all_chats @@ -933,9 +863,7 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat_item = db.get(Chat, id) @@ -950,9 +878,7 @@ class ChatTable: except Exception: return None - def get_chat_by_share_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_share_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: # it is possible that the shared link was deleted. hence, @@ -966,9 +892,7 @@ class ChatTable: except Exception: return None - def get_chat_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() @@ -976,40 +900,30 @@ class ChatTable: except Exception: return None - def is_chat_owner( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def is_chat_owner(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: """ Lightweight ownership check — uses EXISTS subquery instead of loading the full Chat row (which includes the potentially large JSON blob). """ try: with get_db_context(db) as db: - return db.query( - exists().where(and_(Chat.id == id, Chat.user_id == user_id)) - ).scalar() + return db.query(exists().where(and_(Chat.id == id, Chat.user_id == user_id))).scalar() except Exception: return False - def get_chat_folder_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[str]: + def get_chat_folder_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[str]: """ Fetch only the folder_id column for a chat, without loading the full JSON blob. Returns None if chat doesn't exist or doesn't belong to user. """ try: with get_db_context(db) as db: - result = ( - db.query(Chat.folder_id).filter_by(id=id, user_id=user_id).first() - ) + result = db.query(Chat.folder_id).filter_by(id=id, user_id=user_id).first() return result[0] if result else None except Exception: return None - def get_chats( - self, skip: int = 0, limit: int = 50, db: Optional[Session] = None - ) -> list[ChatModel]: + def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[ChatModel]: with get_db_context(db) as db: all_chats = ( db.query(Chat) @@ -1030,22 +944,18 @@ class ChatTable: query = db.query(Chat).filter_by(user_id=user_id) if filter: - if filter.get("updated_at"): - query = query.filter(Chat.updated_at > filter.get("updated_at")) + if filter.get('updated_at'): + query = query.filter(Chat.updated_at > filter.get('updated_at')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction: if hasattr(Chat, order_by): - if direction.lower() == "asc": - query = query.order_by( - getattr(Chat, order_by).asc(), Chat.id - ) - elif direction.lower() == "desc": - query = query.order_by( - getattr(Chat, order_by).desc(), Chat.id - ) + if direction.lower() == 'asc': + query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) + elif direction.lower() == 'desc': + query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: query = query.order_by(Chat.updated_at.desc(), Chat.id) @@ -1063,14 +973,12 @@ class ChatTable: return ChatListResponse( **{ - "items": [ChatModel.model_validate(chat) for chat in all_chats], - "total": total, + 'items': [ChatModel.model_validate(chat) for chat in all_chats], + 'total': total, } ) - def get_pinned_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[ChatTitleIdResponse]: + def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatTitleIdResponse]: with get_db_context(db) as db: all_chats = ( db.query(Chat) @@ -1081,24 +989,18 @@ class ChatTable: return [ ChatTitleIdResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "updated_at": chat[2], - "created_at": chat[3], + 'id': chat[0], + 'title': chat[1], + 'updated_at': chat[2], + 'created_at': chat[3], } ) for chat in all_chats ] - def get_archived_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[ChatModel]: + def get_archived_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]: with get_db_context(db) as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) + all_chats = db.query(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc()) return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id_and_search_text( @@ -1116,61 +1018,53 @@ class ChatTable: search_text = sanitize_text_for_db(search_text).lower().strip() if not search_text: - return self.get_chat_list_by_user_id( - user_id, include_archived, filter={}, skip=skip, limit=limit, db=db - ) + return self.get_chat_list_by_user_id(user_id, include_archived, filter={}, skip=skip, limit=limit, db=db) - search_text_words = search_text.split(" ") + search_text_words = search_text.split(' ') # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags tag_ids = [ - word.replace("tag:", "").replace(" ", "_").lower() - for word in search_text_words - if word.startswith("tag:") + word.replace('tag:', '').replace(' ', '_').lower() for word in search_text_words if word.startswith('tag:') ] # Extract folder names - handle spaces and case insensitivity folders = Folders.search_folders_by_names( user_id, - [ - word.replace("folder:", "") - for word in search_text_words - if word.startswith("folder:") - ], + [word.replace('folder:', '') for word in search_text_words if word.startswith('folder:')], ) folder_ids = [folder.id for folder in folders] is_pinned = None - if "pinned:true" in search_text_words: + if 'pinned:true' in search_text_words: is_pinned = True - elif "pinned:false" in search_text_words: + elif 'pinned:false' in search_text_words: is_pinned = False is_archived = None - if "archived:true" in search_text_words: + if 'archived:true' in search_text_words: is_archived = True - elif "archived:false" in search_text_words: + elif 'archived:false' in search_text_words: is_archived = False is_shared = None - if "shared:true" in search_text_words: + if 'shared:true' in search_text_words: is_shared = True - elif "shared:false" in search_text_words: + elif 'shared:false' in search_text_words: is_shared = False search_text_words = [ word for word in search_text_words if ( - not word.startswith("tag:") - and not word.startswith("folder:") - and not word.startswith("pinned:") - and not word.startswith("archived:") - and not word.startswith("shared:") + not word.startswith('tag:') + and not word.startswith('folder:') + and not word.startswith('pinned:') + and not word.startswith('archived:') + and not word.startswith('shared:') ) ] - search_text = " ".join(search_text_words) + search_text = ' '.join(search_text_words) with get_db_context(db) as db: query = db.query(Chat).filter(Chat.user_id == user_id) @@ -1196,30 +1090,32 @@ class ChatTable: # Check if the database dialect is either 'sqlite' or 'postgresql' dialect_name = db.bind.dialect.name - if dialect_name == "sqlite": + if dialect_name == 'sqlite': # SQLite case: using JSON1 extension for JSON searching sqlite_content_sql = ( - "EXISTS (" - " SELECT 1 " + 'EXISTS (' + ' SELECT 1 ' " FROM json_each(Chat.chat, '$.messages') AS message " " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'" - ")" + ')' ) sqlite_content_clause = text(sqlite_content_sql) query = query.filter( - or_( - Chat.title.ilike(bindparam("title_key")), sqlite_content_clause - ).params(title_key=f"%{search_text}%", content_key=search_text) + or_(Chat.title.ilike(bindparam('title_key')), sqlite_content_clause).params( + title_key=f'%{search_text}%', content_key=search_text + ) ) # Check if there are any tags to filter, it should have all the tags - if "none" in tag_ids: - query = query.filter(text(""" + if 'none' in tag_ids: + query = query.filter( + text(""" NOT EXISTS ( SELECT 1 FROM json_each(Chat.meta, '$.tags') AS tag ) - """)) + """) + ) elif tag_ids: query = query.filter( and_( @@ -1230,13 +1126,13 @@ class ChatTable: FROM json_each(Chat.meta, '$.tags') AS tag WHERE tag.value = :tag_id_{tag_idx} ) - """).params(**{f"tag_id_{tag_idx}": tag_id}) + """).params(**{f'tag_id_{tag_idx}': tag_id}) for tag_idx, tag_id in enumerate(tag_ids) ] ) ) - elif dialect_name == "postgresql": + elif dialect_name == 'postgresql': # PostgreSQL doesn't allow null bytes in text. We filter those out by checking # the JSON representation for \u0000 before attempting text extraction @@ -1259,19 +1155,21 @@ class ChatTable: query = query.filter( or_( - Chat.title.ilike(bindparam("title_key")), + Chat.title.ilike(bindparam('title_key')), postgres_content_clause, ) - ).params(title_key=f"%{search_text}%", content_key=search_text.lower()) + ).params(title_key=f'%{search_text}%', content_key=search_text.lower()) # Check if there are any tags to filter, it should have all the tags - if "none" in tag_ids: - query = query.filter(text(""" + if 'none' in tag_ids: + query = query.filter( + text(""" NOT EXISTS ( SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') AS tag ) - """)) + """) + ) elif tag_ids: query = query.filter( and_( @@ -1282,20 +1180,18 @@ class ChatTable: FROM json_array_elements_text(Chat.meta->'tags') AS tag WHERE tag = :tag_id_{tag_idx} ) - """).params(**{f"tag_id_{tag_idx}": tag_id}) + """).params(**{f'tag_id_{tag_idx}': tag_id}) for tag_idx, tag_id in enumerate(tag_ids) ] ) ) else: - raise NotImplementedError( - f"Unsupported dialect: {db.bind.dialect.name}" - ) + raise NotImplementedError(f'Unsupported dialect: {db.bind.dialect.name}') # Perform pagination at the SQL level all_chats = query.offset(skip).limit(limit).all() - log.info(f"The number of chats: {len(all_chats)}") + log.info(f'The number of chats: {len(all_chats)}') # Validate and return chats return [ChatModel.model_validate(chat) for chat in all_chats] @@ -1327,9 +1223,7 @@ class ChatTable: self, folder_ids: list[str], user_id: str, db: Optional[Session] = None ) -> list[ChatModel]: with get_db_context(db) as db: - query = db.query(Chat).filter( - Chat.folder_id.in_(folder_ids), Chat.user_id == user_id - ) + query = db.query(Chat).filter(Chat.folder_id.in_(folder_ids), Chat.user_id == user_id) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter_by(archived=False) @@ -1353,12 +1247,10 @@ class ChatTable: except Exception: return None - def get_chat_tags_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> list[TagModel]: + def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[TagModel]: with get_db_context(db) as db: chat = db.get(Chat, id) - tag_ids = chat.meta.get("tags", []) + tag_ids = chat.meta.get('tags', []) return Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db) def get_chat_list_by_user_id_and_tag_name( @@ -1371,44 +1263,38 @@ class ChatTable: ) -> list[ChatModel]: with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) - tag_id = tag_name.replace(" ", "_").lower() + tag_id = tag_name.replace(' ', '_').lower() - log.info(f"DB dialect name: {db.bind.dialect.name}") - if db.bind.dialect.name == "sqlite": + log.info(f'DB dialect name: {db.bind.dialect.name}') + if db.bind.dialect.name == 'sqlite': # SQLite JSON1 querying for tags within the meta JSON field query = query.filter( - text( - f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" - ) + text(f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)") ).params(tag_id=tag_id) - elif db.bind.dialect.name == "postgresql": + elif db.bind.dialect.name == 'postgresql': # PostgreSQL JSON query for tags within the meta JSON field (for `json` type) query = query.filter( - text( - "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" - ) + text("EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)") ).params(tag_id=tag_id) else: - raise NotImplementedError( - f"Unsupported dialect: {db.bind.dialect.name}" - ) + raise NotImplementedError(f'Unsupported dialect: {db.bind.dialect.name}') all_chats = query.all() - log.debug(f"all_chats: {all_chats}") + log.debug(f'all_chats: {all_chats}') return [ChatModel.model_validate(chat) for chat in all_chats] def add_chat_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> Optional[ChatModel]: - tag_id = tag_name.replace(" ", "_").lower() + tag_id = tag_name.replace(' ', '_').lower() Tags.ensure_tags_exist([tag_name], user_id, db=db) try: with get_db_context(db) as db: chat = db.get(Chat, id) - if tag_id not in chat.meta.get("tags", []): + if tag_id not in chat.meta.get('tags', []): chat.meta = { **chat.meta, - "tags": list(set(chat.meta.get("tags", []) + [tag_id])), + 'tags': list(set(chat.meta.get('tags', []) + [tag_id])), } db.commit() db.refresh(chat) @@ -1416,29 +1302,21 @@ class ChatTable: except Exception: return None - def count_chats_by_tag_name_and_user_id( - self, tag_name: str, user_id: str, db: Optional[Session] = None - ) -> int: + def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[Session] = None) -> int: with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id, archived=False) - tag_id = tag_name.replace(" ", "_").lower() + tag_id = tag_name.replace(' ', '_').lower() - if db.bind.dialect.name == "sqlite": + if db.bind.dialect.name == 'sqlite': query = query.filter( - text( - "EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" - ) + text("EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)") ).params(tag_id=tag_id) - elif db.bind.dialect.name == "postgresql": + elif db.bind.dialect.name == 'postgresql': query = query.filter( - text( - "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" - ) + text("EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)") ).params(tag_id=tag_id) else: - raise NotImplementedError( - f"Unsupported dialect: {db.bind.dialect.name}" - ) + raise NotImplementedError(f'Unsupported dialect: {db.bind.dialect.name}') return query.count() @@ -1467,9 +1345,7 @@ class ChatTable: orphans.append(tag_id) Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db) - def count_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str, db: Optional[Session] = None - ) -> int: + def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[Session] = None) -> int: with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) @@ -1485,28 +1361,26 @@ class ChatTable: try: with get_db_context(db) as db: chat = db.get(Chat, id) - tags = chat.meta.get("tags", []) - tag_id = tag_name.replace(" ", "_").lower() + tags = chat.meta.get('tags', []) + tag_id = tag_name.replace(' ', '_').lower() tags = [tag for tag in tags if tag != tag_id] chat.meta = { **chat.meta, - "tags": list(set(tags)), + 'tags': list(set(tags)), } db.commit() return True except Exception: return False - def delete_all_tags_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: chat = db.get(Chat, id) chat.meta = { **chat.meta, - "tags": [], + 'tags': [], } db.commit() @@ -1525,9 +1399,7 @@ class ChatTable: except Exception: return False - def delete_chat_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: db.query(ChatMessage).filter_by(chat_id=id).delete() @@ -1538,19 +1410,15 @@ class ChatTable: except Exception: return False - def delete_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: self.delete_shared_chats_by_user_id(user_id, db=db) - chat_id_subquery = ( - db.query(Chat.id).filter_by(user_id=user_id).subquery() + chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id).subquery() + db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete( + synchronize_session=False ) - db.query(ChatMessage).filter( - ChatMessage.chat_id.in_(chat_id_subquery) - ).delete(synchronize_session=False) db.query(Chat).filter_by(user_id=user_id).delete() db.commit() @@ -1558,19 +1426,13 @@ class ChatTable: except Exception: return False - def delete_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str, db: Optional[Session] = None - ) -> bool: + def delete_chats_by_user_id_and_folder_id(self, user_id: str, folder_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - chat_id_subquery = ( - db.query(Chat.id) - .filter_by(user_id=user_id, folder_id=folder_id) - .subquery() + chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id, folder_id=folder_id).subquery() + db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete( + synchronize_session=False ) - db.query(ChatMessage).filter( - ChatMessage.chat_id.in_(chat_id_subquery) - ).delete(synchronize_session=False) db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() db.commit() @@ -1587,32 +1449,22 @@ class ChatTable: ) -> bool: try: with get_db_context(db) as db: - db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update( - {"folder_id": new_folder_id} - ) + db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update({'folder_id': new_folder_id}) db.commit() return True except Exception: return False - def delete_shared_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + shared_chat_ids = [f'shared-{chat.id}' for chat in chats_by_user] # Use subquery to delete chat_messages for shared chats - shared_id_subq = ( - db.query(Chat.id) - .filter(Chat.user_id.in_(shared_chat_ids)) - .subquery() - ) - db.query(ChatMessage).filter( - ChatMessage.chat_id.in_(shared_id_subq) - ).delete(synchronize_session=False) + shared_id_subq = db.query(Chat.id).filter(Chat.user_id.in_(shared_chat_ids)).subquery() + db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_id_subq)).delete(synchronize_session=False) db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() db.commit() @@ -1632,21 +1484,10 @@ class ChatTable: return None chat_message_file_ids = [ - item.id - for item in self.get_chat_files_by_chat_id_and_message_id( - chat_id, message_id, db=db - ) + item.id for item in self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=db) ] # Remove duplicates and existing file_ids - file_ids = list( - set( - [ - file_id - for file_id in file_ids - if file_id and file_id not in chat_message_file_ids - ] - ) - ) + file_ids = list(set([file_id for file_id in file_ids if file_id and file_id not in chat_message_file_ids])) if not file_ids: return None @@ -1667,9 +1508,7 @@ class ChatTable: for file_id in file_ids ] - results = [ - ChatFile(**chat_file.model_dump()) for chat_file in chat_files - ] + results = [ChatFile(**chat_file.model_dump()) for chat_file in chat_files] db.add_all(results) db.commit() @@ -1688,13 +1527,9 @@ class ChatTable: .order_by(ChatFile.created_at.asc()) .all() ) - return [ - ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files - ] + return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files] - def delete_chat_file( - self, chat_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete() @@ -1703,9 +1538,7 @@ class ChatTable: except Exception: return False - def get_shared_chats_by_file_id( - self, file_id: str, db: Optional[Session] = None - ) -> list[ChatModel]: + def get_shared_chats_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChatModel]: with get_db_context(db) as db: # Join Chat and ChatFile tables to get shared chats associated with the file_id all_chats = ( diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 406adb2559..aa6c7bdcae 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -19,7 +19,7 @@ log = logging.getLogger(__name__) class Feedback(Base): - __tablename__ = "feedback" + __tablename__ = 'feedback' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) version = Column(BigInteger, default=0) @@ -81,7 +81,7 @@ class RatingData(BaseModel): sibling_model_ids: Optional[list[str]] = None reason: Optional[str] = None comment: Optional[str] = None - model_config = ConfigDict(extra="allow", protected_namespaces=()) + model_config = ConfigDict(extra='allow', protected_namespaces=()) class MetaData(BaseModel): @@ -89,12 +89,12 @@ class MetaData(BaseModel): chat_id: Optional[str] = None message_id: Optional[str] = None tags: Optional[list[str]] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class SnapshotData(BaseModel): chat: Optional[dict] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FeedbackForm(BaseModel): @@ -102,14 +102,14 @@ class FeedbackForm(BaseModel): data: Optional[RatingData] = None meta: Optional[dict] = None snapshot: Optional[SnapshotData] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class UserResponse(BaseModel): id: str name: str email: str - role: str = "pending" + role: str = 'pending' last_active_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -146,12 +146,12 @@ class FeedbackTable: id = str(uuid.uuid4()) feedback = FeedbackModel( **{ - "id": id, - "user_id": user_id, - "version": 0, + 'id': id, + 'user_id': user_id, + 'version': 0, **form_data.model_dump(), - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) try: @@ -164,12 +164,10 @@ class FeedbackTable: else: return None except Exception as e: - log.exception(f"Error creating a new feedback: {e}") + log.exception(f'Error creating a new feedback: {e}') return None - def get_feedback_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[FeedbackModel]: + def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]: try: with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() @@ -191,16 +189,14 @@ class FeedbackTable: except Exception: return None - def get_feedbacks_by_chat_id( - self, chat_id: str, db: Optional[Session] = None - ) -> list[FeedbackModel]: + def get_feedbacks_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[FeedbackModel]: """Get all feedbacks for a specific chat.""" try: with get_db_context(db) as db: # meta.chat_id stores the chat reference feedbacks = ( db.query(Feedback) - .filter(Feedback.meta["chat_id"].as_string() == chat_id) + .filter(Feedback.meta['chat_id'].as_string() == chat_id) .order_by(Feedback.created_at.desc()) .all() ) @@ -219,36 +215,28 @@ class FeedbackTable: query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) if filter: - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "username": - if direction == "asc": + if order_by == 'username': + if direction == 'asc': query = query.order_by(User.name.asc()) else: query = query.order_by(User.name.desc()) - elif order_by == "model_id": + elif order_by == 'model_id': # it's stored in feedback.data['model_id'] - if direction == "asc": - query = query.order_by( - Feedback.data["model_id"].as_string().asc() - ) + if direction == 'asc': + query = query.order_by(Feedback.data['model_id'].as_string().asc()) else: - query = query.order_by( - Feedback.data["model_id"].as_string().desc() - ) - elif order_by == "rating": + query = query.order_by(Feedback.data['model_id'].as_string().desc()) + elif order_by == 'rating': # it's stored in feedback.data['rating'] - if direction == "asc": - query = query.order_by( - Feedback.data["rating"].as_string().asc() - ) + if direction == 'asc': + query = query.order_by(Feedback.data['rating'].as_string().asc()) else: - query = query.order_by( - Feedback.data["rating"].as_string().desc() - ) - elif order_by == "updated_at": - if direction == "asc": + query = query.order_by(Feedback.data['rating'].as_string().desc()) + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Feedback.updated_at.asc()) else: query = query.order_by(Feedback.updated_at.desc()) @@ -270,9 +258,7 @@ class FeedbackTable: for feedback, user in items: feedback_model = FeedbackModel.model_validate(feedback) user_model = UserResponse.model_validate(user) - feedbacks.append( - FeedbackUserResponse(**feedback_model.model_dump(), user=user_model) - ) + feedbacks.append(FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)) return FeedbackListResponse(items=feedbacks, total=total) @@ -280,14 +266,10 @@ class FeedbackTable: with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) - .order_by(Feedback.updated_at.desc()) - .all() + for feedback in db.query(Feedback).order_by(Feedback.updated_at.desc()).all() ] - def get_all_feedback_ids( - self, db: Optional[Session] = None - ) -> list[FeedbackIdResponse]: + def get_all_feedback_ids(self, db: Optional[Session] = None) -> list[FeedbackIdResponse]: with get_db_context(db) as db: return [ FeedbackIdResponse( @@ -306,14 +288,11 @@ class FeedbackTable: .all() ] - def get_feedbacks_for_leaderboard( - self, db: Optional[Session] = None - ) -> list[LeaderboardFeedbackData]: + def get_feedbacks_for_leaderboard(self, db: Optional[Session] = None) -> list[LeaderboardFeedbackData]: """Fetch only id and data for leaderboard computation (excludes snapshot/meta).""" with get_db_context(db) as db: return [ - LeaderboardFeedbackData(id=row.id, data=row.data) - for row in db.query(Feedback.id, Feedback.data).all() + LeaderboardFeedbackData(id=row.id, data=row.data) for row in db.query(Feedback.id, Feedback.data).all() ] def get_model_evaluation_history( @@ -333,30 +312,26 @@ class FeedbackTable: rows = db.query(Feedback.created_at, Feedback.data).all() else: cutoff = int(time.time()) - (days * 86400) - rows = ( - db.query(Feedback.created_at, Feedback.data) - .filter(Feedback.created_at >= cutoff) - .all() - ) + rows = db.query(Feedback.created_at, Feedback.data).filter(Feedback.created_at >= cutoff).all() - daily_counts = defaultdict(lambda: {"won": 0, "lost": 0}) + daily_counts = defaultdict(lambda: {'won': 0, 'lost': 0}) first_date = None for created_at, data in rows: if not data: continue - if data.get("model_id") != model_id: + if data.get('model_id') != model_id: continue - rating_str = str(data.get("rating", "")) - if rating_str not in ("1", "-1"): + rating_str = str(data.get('rating', '')) + if rating_str not in ('1', '-1'): continue - date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") - if rating_str == "1": - daily_counts[date_str]["won"] += 1 + date_str = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d') + if rating_str == '1': + daily_counts[date_str]['won'] += 1 else: - daily_counts[date_str]["lost"] += 1 + daily_counts[date_str]['lost'] += 1 # Track first date for this model if first_date is None or date_str < first_date: @@ -368,7 +343,7 @@ class FeedbackTable: if days == 0 and first_date: # All time: start from first feedback date - start_date = datetime.strptime(first_date, "%Y-%m-%d").date() + start_date = datetime.strptime(first_date, '%Y-%m-%d').date() num_days = (today - start_date).days + 1 else: # Fixed range @@ -377,36 +352,24 @@ class FeedbackTable: for i in range(num_days): d = start_date + timedelta(days=i) - date_str = d.strftime("%Y-%m-%d") - counts = daily_counts.get(date_str, {"won": 0, "lost": 0}) - result.append( - ModelHistoryEntry(date=date_str, won=counts["won"], lost=counts["lost"]) - ) + date_str = d.strftime('%Y-%m-%d') + counts = daily_counts.get(date_str, {'won': 0, 'lost': 0}) + result.append(ModelHistoryEntry(date=date_str, won=counts['won'], lost=counts['lost'])) return result - def get_feedbacks_by_type( - self, type: str, db: Optional[Session] = None - ) -> list[FeedbackModel]: + def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]: with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) - .filter_by(type=type) - .order_by(Feedback.updated_at.desc()) - .all() + for feedback in db.query(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()).all() ] - def get_feedbacks_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[FeedbackModel]: + def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]: with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) - .filter_by(user_id=user_id) - .order_by(Feedback.updated_at.desc()) - .all() + for feedback in db.query(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()).all() ] def update_feedback_by_id( @@ -462,9 +425,7 @@ class FeedbackTable: db.commit() return True - def delete_feedback_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: @@ -473,9 +434,7 @@ class FeedbackTable: db.commit() return True - def delete_feedbacks_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: result = db.query(Feedback).filter_by(user_id=user_id).delete() db.commit() diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index d1200e4f4c..c02752f130 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -16,7 +16,7 @@ log = logging.getLogger(__name__) class File(Base): - __tablename__ = "file" + __tablename__ = 'file' id = Column(String, primary_key=True, unique=True) user_id = Column(String) hash = Column(Text, nullable=True) @@ -58,9 +58,9 @@ class FileMeta(BaseModel): content_type: Optional[str] = None size: Optional[int] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') - @model_validator(mode="before") + @model_validator(mode='before') @classmethod def sanitize_meta(cls, data): """Sanitize metadata fields to handle malformed legacy data.""" @@ -68,14 +68,12 @@ class FileMeta(BaseModel): return data # Handle content_type that may be a list like ['application/pdf', None] - content_type = data.get("content_type") + content_type = data.get('content_type') if isinstance(content_type, list): # Extract first non-None string value - data["content_type"] = next( - (item for item in content_type if isinstance(item, str)), None - ) + data['content_type'] = next((item for item in content_type if isinstance(item, str)), None) elif content_type is not None and not isinstance(content_type, str): - data["content_type"] = None + data['content_type'] = None return data @@ -92,7 +90,7 @@ class FileModelResponse(BaseModel): created_at: int # timestamp in epoch updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FileMetadataResponse(BaseModel): @@ -123,25 +121,22 @@ class FileUpdateForm(BaseModel): meta: Optional[dict] = None - class FilesTable: - def insert_new_file( - self, user_id: str, form_data: FileForm, db: Optional[Session] = None - ) -> Optional[FileModel]: + def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: file_data = form_data.model_dump() # Sanitize meta to remove non-JSON-serializable objects # (e.g. callable tool functions, MCP client instances from middleware) - if file_data.get("meta"): - file_data["meta"] = sanitize_metadata(file_data["meta"]) + if file_data.get('meta'): + file_data['meta'] = sanitize_metadata(file_data['meta']) file = FileModel( **{ **file_data, - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -155,12 +150,10 @@ class FilesTable: else: return None except Exception as e: - log.exception(f"Error inserting a new file: {e}") + log.exception(f'Error inserting a new file: {e}') return None - def get_file_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[FileModel]: + def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]: try: with get_db_context(db) as db: try: @@ -171,9 +164,7 @@ class FilesTable: except Exception: return None - def get_file_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[FileModel]: + def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id, user_id=user_id).first() @@ -184,9 +175,7 @@ class FilesTable: except Exception: return None - def get_file_metadata_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[FileMetadataResponse]: + def get_file_metadata_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileMetadataResponse]: with get_db_context(db) as db: try: file = db.get(File, id) @@ -204,9 +193,7 @@ class FilesTable: with get_db_context(db) as db: return [FileModel.model_validate(file) for file in db.query(File).all()] - def check_access_by_user_id( - self, id, user_id, permission="write", db: Optional[Session] = None - ) -> bool: + def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool: file = self.get_file_by_id(id, db=db) if not file: return False @@ -215,21 +202,14 @@ class FilesTable: # Implement additional access control logic here as needed return False - def get_files_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[FileModel]: + def get_files_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileModel]: with get_db_context(db) as db: return [ FileModel.model_validate(file) - for file in db.query(File) - .filter(File.id.in_(ids)) - .order_by(File.updated_at.desc()) - .all() + for file in db.query(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()).all() ] - def get_file_metadatas_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[FileMetadataResponse]: + def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileMetadataResponse]: with get_db_context(db) as db: return [ FileMetadataResponse( @@ -239,22 +219,15 @@ class FilesTable: created_at=file.created_at, updated_at=file.updated_at, ) - for file in db.query( - File.id, File.hash, File.meta, File.created_at, File.updated_at - ) + for file in db.query(File.id, File.hash, File.meta, File.created_at, File.updated_at) .filter(File.id.in_(ids)) .order_by(File.updated_at.desc()) .all() ] - def get_files_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[FileModel]: + def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]: with get_db_context(db) as db: - return [ - FileModel.model_validate(file) - for file in db.query(File).filter_by(user_id=user_id).all() - ] + return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()] def get_file_list( self, @@ -262,7 +235,7 @@ class FilesTable: skip: int = 0, limit: int = 50, db: Optional[Session] = None, - ) -> "FileListResponse": + ) -> 'FileListResponse': with get_db_context(db) as db: query = db.query(File) if user_id: @@ -272,10 +245,7 @@ class FilesTable: items = [ FileModel.model_validate(file) - for file in query.order_by(File.updated_at.desc(), File.id.desc()) - .offset(skip) - .limit(limit) - .all() + for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all() ] return FileListResponse(items=items, total=total) @@ -296,17 +266,17 @@ class FilesTable: A SQL LIKE compatible pattern with proper escaping. """ # Escape SQL special characters first, then convert glob wildcards - pattern = glob.replace("\\", "\\\\") - pattern = pattern.replace("%", "\\%") - pattern = pattern.replace("_", "\\_") - pattern = pattern.replace("*", "%") - pattern = pattern.replace("?", "_") + pattern = glob.replace('\\', '\\\\') + pattern = pattern.replace('%', '\\%') + pattern = pattern.replace('_', '\\_') + pattern = pattern.replace('*', '%') + pattern = pattern.replace('?', '_') return pattern def search_files( self, user_id: Optional[str] = None, - filename: str = "*", + filename: str = '*', skip: int = 0, limit: int = 100, db: Optional[Session] = None, @@ -331,15 +301,12 @@ class FilesTable: query = query.filter_by(user_id=user_id) pattern = self._glob_to_like_pattern(filename) - if pattern != "%": - query = query.filter(File.filename.ilike(pattern, escape="\\")) + if pattern != '%': + query = query.filter(File.filename.ilike(pattern, escape='\\')) return [ FileModel.model_validate(file) - for file in query.order_by(File.created_at.desc(), File.id.desc()) - .offset(skip) - .limit(limit) - .all() + for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all() ] def update_file_by_id( @@ -362,12 +329,10 @@ class FilesTable: db.commit() return FileModel.model_validate(file) except Exception as e: - log.exception(f"Error updating file completely by id: {e}") + log.exception(f'Error updating file completely by id: {e}') return None - def update_file_hash_by_id( - self, id: str, hash: Optional[str], db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -379,9 +344,7 @@ class FilesTable: except Exception: return None - def update_file_data_by_id( - self, id: str, data: dict, db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -390,12 +353,9 @@ class FilesTable: db.commit() return FileModel.model_validate(file) except Exception as e: - return None - def update_file_metadata_by_id( - self, id: str, meta: dict, db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index b491b831f2..5311f922e2 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -20,7 +20,7 @@ log = logging.getLogger(__name__) class Folder(Base): - __tablename__ = "folder" + __tablename__ = 'folder' id = Column(Text, primary_key=True, unique=True) parent_id = Column(Text, nullable=True) user_id = Column(Text) @@ -72,14 +72,14 @@ class FolderForm(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None parent_id: Optional[str] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FolderUpdateForm(BaseModel): name: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FolderTable: @@ -94,12 +94,12 @@ class FolderTable: id = str(uuid.uuid4()) folder = FolderModel( **{ - "id": id, - "user_id": user_id, + 'id': id, + 'user_id': user_id, **(form_data.model_dump(exclude_unset=True) or {}), - "parent_id": parent_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'parent_id': parent_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) try: @@ -112,7 +112,7 @@ class FolderTable: else: return None except Exception as e: - log.exception(f"Error inserting a new folder: {e}") + log.exception(f'Error inserting a new folder: {e}') return None def get_folder_by_id_and_user_id( @@ -137,9 +137,7 @@ class FolderTable: folders = [] def get_children(folder): - children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id, db=db - ) + children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db) for child in children: get_children(child) folders.append(child) @@ -153,14 +151,9 @@ class FolderTable: except Exception: return None - def get_folders_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[FolderModel]: + def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]: with get_db_context(db) as db: - return [ - FolderModel.model_validate(folder) - for folder in db.query(Folder).filter_by(user_id=user_id).all() - ] + return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()] def get_folder_by_parent_id_and_user_id_and_name( self, @@ -184,7 +177,7 @@ class FolderTable: return FolderModel.model_validate(folder) except Exception as e: - log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") + log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}') return None def get_folders_by_parent_id_and_user_id( @@ -193,9 +186,7 @@ class FolderTable: with get_db_context(db) as db: return [ FolderModel.model_validate(folder) - for folder in db.query(Folder) - .filter_by(parent_id=parent_id, user_id=user_id) - .all() + for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all() ] def update_folder_parent_id_by_id_and_user_id( @@ -219,7 +210,7 @@ class FolderTable: return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return def update_folder_by_id_and_user_id( @@ -241,7 +232,7 @@ class FolderTable: existing_folder = ( db.query(Folder) .filter_by( - name=form_data.get("name"), + name=form_data.get('name'), parent_id=folder.parent_id, user_id=user_id, ) @@ -251,17 +242,17 @@ class FolderTable: if existing_folder and existing_folder.id != id: return None - folder.name = form_data.get("name", folder.name) - if "data" in form_data: + folder.name = form_data.get('name', folder.name) + if 'data' in form_data: folder.data = { **(folder.data or {}), - **form_data["data"], + **form_data['data'], } - if "meta" in form_data: + if 'meta' in form_data: folder.meta = { **(folder.meta or {}), - **form_data["meta"], + **form_data['meta'], } folder.updated_at = int(time.time()) @@ -269,7 +260,7 @@ class FolderTable: return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return def update_folder_is_expanded_by_id_and_user_id( @@ -289,12 +280,10 @@ class FolderTable: return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return - def delete_folder_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> list[str]: + def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]: try: folder_ids = [] with get_db_context(db) as db: @@ -306,11 +295,8 @@ class FolderTable: # Delete all children folders def delete_children(folder): - folder_children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id, db=db - ) + folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db) for folder_child in folder_children: - delete_children(folder_child) folder_ids.append(folder_child.id) @@ -323,12 +309,12 @@ class FolderTable: db.commit() return folder_ids except Exception as e: - log.error(f"delete_folder: {e}") + log.error(f'delete_folder: {e}') return [] def normalize_folder_name(self, name: str) -> str: # Replace _ and space with a single space, lower case, collapse multiple spaces - name = re.sub(r"[\s_]+", " ", name) + name = re.sub(r'[\s_]+', ' ', name) return name.strip().lower() def search_folders_by_names( @@ -349,9 +335,7 @@ class FolderTable: results[folder.id] = FolderModel.model_validate(folder) # get children folders - children = self.get_children_folders_by_id_and_user_id( - folder.id, user_id, db=db - ) + children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db) for child in children: results[child.id] = child diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 18916315e6..3c29b4fa93 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -16,7 +16,7 @@ log = logging.getLogger(__name__) class Function(Base): - __tablename__ = "function" + __tablename__ = 'function' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -30,13 +30,13 @@ class Function(Base): updated_at = Column(BigInteger) created_at = Column(BigInteger) - __table_args__ = (Index("is_global_idx", "is_global"),) + __table_args__ = (Index('is_global_idx', 'is_global'),) class FunctionMeta(BaseModel): description: Optional[str] = None manifest: Optional[dict] = {} - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FunctionModel(BaseModel): @@ -113,10 +113,10 @@ class FunctionsTable: function = FunctionModel( **{ **form_data.model_dump(), - "user_id": user_id, - "type": type, - "updated_at": int(time.time()), - "created_at": int(time.time()), + 'user_id': user_id, + 'type': type, + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) @@ -131,7 +131,7 @@ class FunctionsTable: else: return None except Exception as e: - log.exception(f"Error creating a new function: {e}") + log.exception(f'Error creating a new function: {e}') return None def sync_functions( @@ -156,16 +156,16 @@ class FunctionsTable: db.query(Function).filter_by(id=func.id).update( { **func.model_dump(), - "user_id": user_id, - "updated_at": int(time.time()), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) else: new_func = Function( **{ **func.model_dump(), - "user_id": user_id, - "updated_at": int(time.time()), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) db.add(new_func) @@ -177,17 +177,12 @@ class FunctionsTable: db.commit() - return [ - FunctionModel.model_validate(func) - for func in db.query(Function).all() - ] + return [FunctionModel.model_validate(func) for func in db.query(Function).all()] except Exception as e: - log.exception(f"Error syncing functions for user {user_id}: {e}") + log.exception(f'Error syncing functions for user {user_id}: {e}') return [] - def get_function_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[FunctionModel]: + def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]: try: with get_db_context(db) as db: function = db.get(Function, id) @@ -195,9 +190,7 @@ class FunctionsTable: except Exception: return None - def get_functions_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]: """ Batch fetch multiple functions by their IDs in a single query. Returns functions in the same order as the input IDs (None entries filtered out). @@ -225,18 +218,11 @@ class FunctionsTable: functions = db.query(Function).all() if include_valves: - return [ - FunctionWithValvesModel.model_validate(function) - for function in functions - ] + return [FunctionWithValvesModel.model_validate(function) for function in functions] else: - return [ - FunctionModel.model_validate(function) for function in functions - ] + return [FunctionModel.model_validate(function) for function in functions] - def get_function_list( - self, db: Optional[Session] = None - ) -> list[FunctionUserResponse]: + def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]: with get_db_context(db) as db: functions = db.query(Function).order_by(Function.updated_at.desc()).all() user_ids = list(set(func.user_id for func in functions)) @@ -248,69 +234,48 @@ class FunctionsTable: FunctionUserResponse.model_validate( { **FunctionModel.model_validate(func).model_dump(), - "user": ( - users_dict.get(func.user_id).model_dump() - if func.user_id in users_dict - else None - ), + 'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None), } ) for func in functions ] - def get_functions_by_type( - self, type: str, active_only=False, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: if active_only: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type=type, is_active=True) - .all() + for function in db.query(Function).filter_by(type=type, is_active=True).all() ] else: return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(type=type).all() + FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions( - self, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() + for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all() ] - def get_global_action_functions( - self, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type="action", is_active=True, is_global=True) - .all() + for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all() ] - def get_function_valves_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]: with get_db_context(db) as db: try: function = db.get(Function, id) return function.valves if function.valves else {} except Exception as e: - log.exception(f"Error getting function valves by id {id}: {e}") + log.exception(f'Error getting function valves by id {id}: {e}') return None - def get_function_valves_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> dict[str, dict]: + def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]: """ Batch fetch valves for multiple functions in a single query. Returns a dict mapping function_id -> valves dict. @@ -320,14 +285,10 @@ class FunctionsTable: return {} try: with get_db_context(db) as db: - functions = ( - db.query(Function.id, Function.valves) - .filter(Function.id.in_(ids)) - .all() - ) + functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all() return {f.id: (f.valves if f.valves else {}) for f in functions} except Exception as e: - log.exception(f"Error batch-fetching function valves: {e}") + log.exception(f'Error batch-fetching function valves: {e}') return {} def update_function_valves_by_id( @@ -364,25 +325,23 @@ class FunctionsTable: else: return None except Exception as e: - log.exception(f"Error updating function metadata by id {id}: {e}") + log.exception(f'Error updating function metadata by id {id}: {e}') return None - def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]: try: user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings - if "functions" not in user_settings: - user_settings["functions"] = {} - if "valves" not in user_settings["functions"]: - user_settings["functions"]["valves"] = {} + if 'functions' not in user_settings: + user_settings['functions'] = {} + if 'valves' not in user_settings['functions']: + user_settings['functions']['valves'] = {} - return user_settings["functions"]["valves"].get(id, {}) + return user_settings['functions']['valves'].get(id, {}) except Exception as e: - log.exception(f"Error getting user values by id {id} and user id {user_id}") + log.exception(f'Error getting user values by id {id} and user id {user_id}') return None def update_user_valves_by_id_and_user_id( @@ -393,32 +352,28 @@ class FunctionsTable: user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings - if "functions" not in user_settings: - user_settings["functions"] = {} - if "valves" not in user_settings["functions"]: - user_settings["functions"]["valves"] = {} + if 'functions' not in user_settings: + user_settings['functions'] = {} + if 'valves' not in user_settings['functions']: + user_settings['functions']['valves'] = {} - user_settings["functions"]["valves"][id] = valves + user_settings['functions']['valves'][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) + Users.update_user_by_id(user_id, {'settings': user_settings}, db=db) - return user_settings["functions"]["valves"][id] + return user_settings['functions']['valves'][id] except Exception as e: - log.exception( - f"Error updating user valves by id {id} and user_id {user_id}: {e}" - ) + log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}') return None - def update_function_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[FunctionModel]: + def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]: with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).update( { **updated, - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -432,8 +387,8 @@ class FunctionsTable: try: db.query(Function).update( { - "is_active": False, - "updated_at": int(time.time()), + 'is_active': False, + 'updated_at': int(time.time()), } ) db.commit() diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 4c7f456e59..d6c2fc9450 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -34,7 +34,7 @@ log = logging.getLogger(__name__) class Group(Base): - __tablename__ = "group" + __tablename__ = 'group' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text) @@ -70,12 +70,12 @@ class GroupModel(BaseModel): class GroupMember(Base): - __tablename__ = "group_member" + __tablename__ = 'group_member' id = Column(Text, unique=True, primary_key=True) group_id = Column( Text, - ForeignKey("group.id", ondelete="CASCADE"), + ForeignKey('group.id', ondelete='CASCADE'), nullable=False, ) user_id = Column(Text, nullable=False) @@ -133,28 +133,26 @@ class GroupListResponse(BaseModel): class GroupTable: def _ensure_default_share_config(self, group_data: dict) -> dict: """Ensure the group data dict has a default share config if not already set.""" - if "data" not in group_data or group_data["data"] is None: - group_data["data"] = {} - if "config" not in group_data["data"]: - group_data["data"]["config"] = {} - if "share" not in group_data["data"]["config"]: - group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION + if 'data' not in group_data or group_data['data'] is None: + group_data['data'] = {} + if 'config' not in group_data['data']: + group_data['data']['config'] = {} + if 'share' not in group_data['data']['config']: + group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION return group_data def insert_new_group( self, user_id: str, form_data: GroupForm, db: Optional[Session] = None ) -> Optional[GroupModel]: with get_db_context(db) as db: - group_data = self._ensure_default_share_config( - form_data.model_dump(exclude_none=True) - ) + group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True)) group = GroupModel( **{ **group_data, - "id": str(uuid.uuid4()), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -183,19 +181,19 @@ class GroupTable: .where(GroupMember.group_id == Group.id) .correlate(Group) .scalar_subquery() - .label("member_count") + .label('member_count') ) query = db.query(Group, member_count) if filter: - if "query" in filter: - query = query.filter(Group.name.ilike(f"%{filter['query']}%")) + if 'query' in filter: + query = query.filter(Group.name.ilike(f'%{filter["query"]}%')) # When share filter is present, member check is handled in the share logic - if "share" in filter: - share_value = filter["share"] - member_id = filter.get("member_id") - json_share = Group.data["config"]["share"] + if 'share' in filter: + share_value = filter['share'] + member_id = filter.get('member_id') + json_share = Group.data['config']['share'] json_share_str = json_share.as_string() json_share_lower = func.lower(json_share_str) @@ -203,37 +201,27 @@ class GroupTable: anyone_can_share = or_( Group.data.is_(None), json_share_str.is_(None), - json_share_lower == "true", - json_share_lower == "1", # Handle SQLite boolean true + json_share_lower == 'true', + json_share_lower == '1', # Handle SQLite boolean true ) if member_id: - member_groups_select = select(GroupMember.group_id).where( - GroupMember.user_id == member_id - ) + member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id) members_only_and_is_member = and_( - json_share_lower == "members", + json_share_lower == 'members', Group.id.in_(member_groups_select), ) - query = query.filter( - or_(anyone_can_share, members_only_and_is_member) - ) + query = query.filter(or_(anyone_can_share, members_only_and_is_member)) else: query = query.filter(anyone_can_share) else: - query = query.filter( - and_(Group.data.isnot(None), json_share_lower == "false") - ) + query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false')) else: # Only apply member_id filter when share filter is NOT present - if "member_id" in filter: + if 'member_id' in filter: query = query.filter( - Group.id.in_( - select(GroupMember.group_id).where( - GroupMember.user_id == filter["member_id"] - ) - ) + Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id'])) ) results = query.order_by(Group.updated_at.desc()).all() @@ -242,7 +230,7 @@ class GroupTable: GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": count or 0, + 'member_count': count or 0, } ) for group, count in results @@ -259,22 +247,16 @@ class GroupTable: query = db.query(Group) if filter: - if "query" in filter: - query = query.filter(Group.name.ilike(f"%{filter['query']}%")) - if "member_id" in filter: + if 'query' in filter: + query = query.filter(Group.name.ilike(f'%{filter["query"]}%')) + if 'member_id' in filter: query = query.filter( - Group.id.in_( - select(GroupMember.group_id).where( - GroupMember.user_id == filter["member_id"] - ) - ) + Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id'])) ) - if "share" in filter: - share_value = filter["share"] - query = query.filter( - Group.data.op("->>")("share") == str(share_value) - ) + if 'share' in filter: + share_value = filter['share'] + query = query.filter(Group.data.op('->>')('share') == str(share_value)) total = query.count() @@ -283,32 +265,24 @@ class GroupTable: .where(GroupMember.group_id == Group.id) .correlate(Group) .scalar_subquery() - .label("member_count") - ) - results = ( - query.add_columns(member_count) - .order_by(Group.updated_at.desc()) - .offset(skip) - .limit(limit) - .all() + .label('member_count') ) + results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all() return { - "items": [ + 'items': [ GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": count or 0, + 'member_count': count or 0, } ) for group, count in results ], - "total": total, + 'total': total, } - def get_groups_by_member_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[GroupModel]: + def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]: with get_db_context(db) as db: return [ GroupModel.model_validate(group) @@ -340,9 +314,7 @@ class GroupTable: return user_groups - def get_group_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[GroupModel]: + def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]: try: with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() @@ -350,41 +322,29 @@ class GroupTable: except Exception: return None - def get_group_user_ids_by_id( - self, id: str, db: Optional[Session] = None - ) -> list[str]: + def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]: with get_db_context(db) as db: - members = ( - db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() - ) + members = db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() if not members: return [] return [m[0] for m in members] - def get_group_user_ids_by_ids( - self, group_ids: list[str], db: Optional[Session] = None - ) -> dict[str, list[str]]: + def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]: with get_db_context(db) as db: members = ( - db.query(GroupMember.group_id, GroupMember.user_id) - .filter(GroupMember.group_id.in_(group_ids)) - .all() + db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all() ) - group_user_ids: dict[str, list[str]] = { - group_id: [] for group_id in group_ids - } + group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids} for group_id, user_id in members: group_user_ids[group_id].append(user_id) return group_user_ids - def set_group_user_ids_by_id( - self, group_id: str, user_ids: list[str], db: Optional[Session] = None - ) -> None: + def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None: with get_db_context(db) as db: # Delete existing members db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() @@ -405,20 +365,12 @@ class GroupTable: db.add_all(new_members) db.commit() - def get_group_member_count_by_id( - self, id: str, db: Optional[Session] = None - ) -> int: + def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int: with get_db_context(db) as db: - count = ( - db.query(func.count(GroupMember.user_id)) - .filter(GroupMember.group_id == id) - .scalar() - ) + count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar() return count if count else 0 - def get_group_member_counts_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> dict[str, int]: + def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]: if not ids: return {} with get_db_context(db) as db: @@ -442,7 +394,7 @@ class GroupTable: db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -470,9 +422,7 @@ class GroupTable: except Exception: return False - def remove_user_from_all_groups( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: # Find all groups the user belongs to @@ -489,9 +439,7 @@ class GroupTable: GroupMember.group_id == group.id, GroupMember.user_id == user_id ).delete() - db.query(Group).filter_by(id=group.id).update( - {"updated_at": int(time.time())} - ) + db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())}) db.commit() return True @@ -503,7 +451,6 @@ class GroupTable: def create_groups_by_group_names( self, user_id: str, group_names: list[str], db: Optional[Session] = None ) -> list[GroupModel]: - # check for existing groups existing_groups = self.get_all_groups(db=db) existing_group_names = {group.name for group in existing_groups} @@ -517,10 +464,10 @@ class GroupTable: id=str(uuid.uuid4()), user_id=user_id, name=group_name, - description="", + description='', data={ - "config": { - "share": DEFAULT_GROUP_SHARE_PERMISSION, + 'config': { + 'share': DEFAULT_GROUP_SHARE_PERMISSION, } }, created_at=int(time.time()), @@ -537,17 +484,13 @@ class GroupTable: continue return new_groups - def sync_groups_by_group_names( - self, user_id: str, group_names: list[str], db: Optional[Session] = None - ) -> bool: + def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: now = int(time.time()) # 1. Groups that SHOULD contain the user - target_groups = ( - db.query(Group).filter(Group.name.in_(group_names)).all() - ) + target_groups = db.query(Group).filter(Group.name.in_(group_names)).all() target_group_ids = {g.id for g in target_groups} # 2. Groups the user is CURRENTLY in @@ -571,7 +514,7 @@ class GroupTable: ).delete(synchronize_session=False) db.query(Group).filter(Group.id.in_(groups_to_remove)).update( - {"updated_at": now}, synchronize_session=False + {'updated_at': now}, synchronize_session=False ) # 5. Bulk insert missing memberships @@ -588,7 +531,7 @@ class GroupTable: if groups_to_add: db.query(Group).filter(Group.id.in_(groups_to_add)).update( - {"updated_at": now}, synchronize_session=False + {'updated_at': now}, synchronize_session=False ) db.commit() @@ -656,9 +599,9 @@ class GroupTable: return GroupModel.model_validate(group) # Remove users from group_member in batch - db.query(GroupMember).filter( - GroupMember.group_id == id, GroupMember.user_id.in_(user_ids) - ).delete(synchronize_session=False) + db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete( + synchronize_session=False + ) # Update group timestamp group.updated_at = int(time.time()) diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index 4212cf0fd5..0495abfb39 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -38,7 +38,7 @@ log = logging.getLogger(__name__) class Knowledge(Base): - __tablename__ = "knowledge" + __tablename__ = 'knowledge' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text) @@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel): class KnowledgeFile(Base): - __tablename__ = "knowledge_file" + __tablename__ = 'knowledge_file' id = Column(Text, unique=True, primary_key=True) - knowledge_id = Column( - Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False - ) - file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False) + file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False) user_id = Column(Text, nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) - __table_args__ = ( - UniqueConstraint( - "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file" - ), - ) + __table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),) class KnowledgeFileModel(BaseModel): @@ -138,10 +132,8 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: - def _get_access_grants( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db) + def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db) def _to_knowledge_model( self, @@ -149,13 +141,9 @@ class KnowledgeTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> KnowledgeModel: - knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump( - exclude={"access_grants"} - ) - knowledge_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(knowledge_data["id"], db=db) + knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'}) + knowledge_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db) ) return KnowledgeModel.model_validate(knowledge_data) @@ -165,23 +153,21 @@ class KnowledgeTable: with get_db_context(db) as db: knowledge = KnowledgeModel( **{ - **form_data.model_dump(exclude={"access_grants"}), - "id": str(uuid.uuid4()), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), - "access_grants": [], + **form_data.model_dump(exclude={'access_grants'}), + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), + 'access_grants': [], } ) try: - result = Knowledge(**knowledge.model_dump(exclude={"access_grants"})) + result = Knowledge(**knowledge.model_dump(exclude={'access_grants'})) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "knowledge", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db) if result: return self._to_knowledge_model(result, db=db) else: @@ -193,17 +179,13 @@ class KnowledgeTable: self, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> list[KnowledgeUserModel]: with get_db_context(db) as db: - all_knowledge = ( - db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() - ) + all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) knowledge_ids = [knowledge.id for knowledge in all_knowledge] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) knowledge_bases = [] for knowledge in all_knowledge: @@ -216,7 +198,7 @@ class KnowledgeTable: access_grants=grants_map.get(knowledge.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -232,27 +214,25 @@ class KnowledgeTable: ) -> KnowledgeListResponse: try: with get_db_context(db) as db: - query = db.query(Knowledge, User).outerjoin( - User, User.id == Knowledge.user_id - ) + query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Knowledge.name.ilike(f"%{query_key}%"), - Knowledge.description.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), - User.username.ilike(f"%{query_key}%"), + Knowledge.name.ilike(f'%{query_key}%'), + Knowledge.description.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), + User.username.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Knowledge.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Knowledge.user_id != user_id) query = AccessGrants.has_permission_filter( @@ -260,8 +240,8 @@ class KnowledgeTable: query=query, DocumentModel=Knowledge, filter=filter, - resource_type="knowledge", - permission="read", + resource_type='knowledge', + permission='read', ) query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) @@ -275,9 +255,7 @@ class KnowledgeTable: items = query.all() knowledge_ids = [kb.id for kb, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) knowledge_bases = [] for knowledge_base, user in items: @@ -289,11 +267,7 @@ class KnowledgeTable: access_grants=grants_map.get(knowledge_base.id, []), db=db, ).model_dump(), - "user": ( - UserModel.model_validate(user).model_dump() - if user - else None - ), + 'user': (UserModel.model_validate(user).model_dump() if user else None), } ) ) @@ -327,15 +301,15 @@ class KnowledgeTable: query=query, DocumentModel=Knowledge, filter=filter, - resource_type="knowledge", - permission="read", + resource_type='knowledge', + permission='read', ) # Apply filename search if filter: - q = filter.get("query") + q = filter.get('query') if q: - query = query.filter(File.filename.ilike(f"%{q}%")) + query = query.filter(File.filename.ilike(f'%{q}%')) # Order by file changes query = query.order_by(File.updated_at.desc(), File.id.asc()) @@ -355,39 +329,27 @@ class KnowledgeTable: items.append( FileUserResponse( **FileModel.model_validate(file).model_dump(), - user=( - UserResponse( - **UserModel.model_validate(user).model_dump() - ) - if user - else None - ), - collection=self._to_knowledge_model( - knowledge, db=db - ).model_dump(), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), + collection=self._to_knowledge_model(knowledge, db=db).model_dump(), ) ) return KnowledgeFileListResponse(items=items, total=total) except Exception as e: - print("search_knowledge_files error:", e) + print('search_knowledge_files error:', e) return KnowledgeFileListResponse(items=[], total=0) - def check_access_by_user_id( - self, id, user_id, permission="write", db: Optional[Session] = None - ) -> bool: + def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool: knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return False if knowledge.user_id == user_id: return True - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, permission=permission, user_group_ids=user_group_ids, @@ -395,19 +357,17 @@ class KnowledgeTable: ) def get_knowledge_bases_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[KnowledgeUserModel]: knowledge_bases = self.get_knowledge_bases(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, permission=permission, user_group_ids=user_group_ids, @@ -415,9 +375,7 @@ class KnowledgeTable: ) ] - def get_knowledge_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[KnowledgeModel]: + def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() @@ -435,23 +393,19 @@ class KnowledgeTable: if knowledge.user_id == user_id: return knowledge - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} if AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ): return knowledge return None - def get_knowledges_by_file_id( - self, file_id: str, db: Optional[Session] = None - ) -> list[KnowledgeModel]: + def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]: try: with get_db_context(db) as db: knowledges = ( @@ -461,9 +415,7 @@ class KnowledgeTable: .all() ) knowledge_ids = [k.id for k in knowledges] - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) return [ self._to_knowledge_model( knowledge, @@ -497,32 +449,26 @@ class KnowledgeTable: primary_sort = File.updated_at.desc() if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: - query = query.filter(or_(File.filename.ilike(f"%{query_key}%"))) + query = query.filter(or_(File.filename.ilike(f'%{query_key}%'))) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(KnowledgeFile.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(KnowledgeFile.user_id != user_id) - order_by = filter.get("order_by") - direction = filter.get("direction") - is_asc = direction == "asc" + order_by = filter.get('order_by') + direction = filter.get('direction') + is_asc = direction == 'asc' - if order_by == "name": - primary_sort = ( - File.filename.asc() if is_asc else File.filename.desc() - ) - elif order_by == "created_at": - primary_sort = ( - File.created_at.asc() if is_asc else File.created_at.desc() - ) - elif order_by == "updated_at": - primary_sort = ( - File.updated_at.asc() if is_asc else File.updated_at.desc() - ) + if order_by == 'name': + primary_sort = File.filename.asc() if is_asc else File.filename.desc() + elif order_by == 'created_at': + primary_sort = File.created_at.asc() if is_asc else File.created_at.desc() + elif order_by == 'updated_at': + primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc() # Apply sort with secondary key for deterministic pagination query = query.order_by(primary_sort, File.id.asc()) @@ -542,13 +488,7 @@ class KnowledgeTable: files.append( FileUserResponse( **FileModel.model_validate(file).model_dump(), - user=( - UserResponse( - **UserModel.model_validate(user).model_dump() - ) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -557,9 +497,7 @@ class KnowledgeTable: print(e) return KnowledgeFileListResponse(items=[], total=0) - def get_files_by_id( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[FileModel]: + def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]: try: with get_db_context(db) as db: files = ( @@ -572,9 +510,7 @@ class KnowledgeTable: except Exception: return [] - def get_file_metadatas_by_id( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[FileMetadataResponse]: + def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]: try: with get_db_context(db) as db: files = self.get_files_by_id(knowledge_id, db=db) @@ -592,12 +528,12 @@ class KnowledgeTable: with get_db_context(db) as db: knowledge_file = KnowledgeFileModel( **{ - "id": str(uuid.uuid4()), - "knowledge_id": knowledge_id, - "file_id": file_id, - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'knowledge_id': knowledge_id, + 'file_id': file_id, + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -613,37 +549,24 @@ class KnowledgeTable: except Exception: return None - def has_file( - self, knowledge_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool: """Check whether a file belongs to a knowledge base.""" try: with get_db_context(db) as db: - return ( - db.query(KnowledgeFile) - .filter_by(knowledge_id=knowledge_id, file_id=file_id) - .first() - is not None - ) + return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None except Exception: return False - def remove_file_from_knowledge_by_id( - self, knowledge_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - db.query(KnowledgeFile).filter_by( - knowledge_id=knowledge_id, file_id=file_id - ).delete() + db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete() db.commit() return True except Exception: return False - def reset_knowledge_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[KnowledgeModel]: + def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: with get_db_context(db) as db: # Delete all knowledge_file entries for this knowledge_id @@ -653,7 +576,7 @@ class KnowledgeTable: # Update the knowledge entry's updated_at timestamp db.query(Knowledge).filter_by(id=id).update( { - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -675,15 +598,13 @@ class KnowledgeTable: knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(exclude={"access_grants"}), - "updated_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'updated_at': int(time.time()), } ) db.commit() if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "knowledge", id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db) return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) @@ -697,8 +618,8 @@ class KnowledgeTable: knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - "data": data, - "updated_at": int(time.time()), + 'data': data, + 'updated_at': int(time.time()), } ) db.commit() @@ -710,7 +631,7 @@ class KnowledgeTable: def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("knowledge", id, db=db) + AccessGrants.revoke_all_access('knowledge', id, db=db) db.query(Knowledge).filter_by(id=id).delete() db.commit() return True @@ -722,7 +643,7 @@ class KnowledgeTable: try: knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()] for knowledge_id in knowledge_ids: - AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db) + AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db) db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index e6b70a3020..17d96adc0e 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -13,7 +13,7 @@ from sqlalchemy import BigInteger, Column, String, Text class Memory(Base): - __tablename__ = "memory" + __tablename__ = 'memory' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -49,11 +49,11 @@ class MemoriesTable: memory = MemoryModel( **{ - "id": id, - "user_id": user_id, - "content": content, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': id, + 'user_id': user_id, + 'content': content, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) result = Memory(**memory.model_dump()) @@ -95,9 +95,7 @@ class MemoriesTable: except Exception: return None - def get_memories_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[MemoryModel]: + def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]: with get_db_context(db) as db: try: memories = db.query(Memory).filter_by(user_id=user_id).all() @@ -105,9 +103,7 @@ class MemoriesTable: except Exception: return None - def get_memory_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[MemoryModel]: + def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]: with get_db_context(db) as db: try: memory = db.get(Memory, id) @@ -126,9 +122,7 @@ class MemoriesTable: except Exception: return False - def delete_memories_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: db.query(Memory).filter_by(user_id=user_id).delete() @@ -138,9 +132,7 @@ class MemoriesTable: except Exception: return False - def delete_memory_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: memory = db.get(Memory, id) diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 0851107b0b..034eaac160 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -21,7 +21,7 @@ from sqlalchemy.sql import exists class MessageReaction(Base): - __tablename__ = "message_reaction" + __tablename__ = 'message_reaction' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) message_id = Column(Text) @@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel): class Message(Base): - __tablename__ = "message" + __tablename__ = 'message' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) @@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel): class MessageUserSlimResponse(MessageUserResponse): data: bool | None = None - @field_validator("data", mode="before") + @field_validator('data', mode='before') def convert_data_to_bool(cls, v): # No data or not a dict → False if not isinstance(v, dict): @@ -152,19 +152,19 @@ class MessageTable: message = MessageModel( **{ - "id": id, - "user_id": user_id, - "channel_id": channel_id, - "reply_to_id": form_data.reply_to_id, - "parent_id": form_data.parent_id, - "is_pinned": False, - "pinned_at": None, - "pinned_by": None, - "content": form_data.content, - "data": form_data.data, - "meta": form_data.meta, - "created_at": ts, - "updated_at": ts, + 'id': id, + 'user_id': user_id, + 'channel_id': channel_id, + 'reply_to_id': form_data.reply_to_id, + 'parent_id': form_data.parent_id, + 'is_pinned': False, + 'pinned_at': None, + 'pinned_by': None, + 'content': form_data.content, + 'data': form_data.data, + 'meta': form_data.meta, + 'created_at': ts, + 'updated_at': ts, } ) result = Message(**message.model_dump()) @@ -186,9 +186,7 @@ class MessageTable: return None reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) @@ -200,22 +198,22 @@ class MessageTable: thread_replies = self.get_thread_replies_by_message_id(id, db=db) # Check if message was sent by webhook (webhook info in meta takes precedence) - webhook_info = message.meta.get("webhook") if message.meta else None - if webhook_info and webhook_info.get("id"): + webhook_info = message.meta.get('webhook') if message.meta else None + if webhook_info and webhook_info.get('id'): # Look up webhook by ID to get current name - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: # Webhook was deleted, use placeholder user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } else: user = Users.get_user_by_id(message.user_id, db=db) @@ -224,79 +222,57 @@ class MessageTable: return MessageResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() if reply_to_message else None - ), - "latest_reply_at": ( - thread_replies[0].created_at if thread_replies else None - ), - "reply_count": len(thread_replies), - "reactions": reactions, + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), + 'latest_reply_at': (thread_replies[0].created_at if thread_replies else None), + 'reply_count': len(thread_replies), + 'reactions': reactions, } ) - def get_thread_replies_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[MessageReplyToResponse]: + def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]: with get_db_context(db) as db: - all_messages = ( - db.query(Message) - .filter_by(parent_id=id) - .order_by(Message.created_at.desc()) - .all() - ) + all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all() messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) return messages - def get_reply_user_ids_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[str]: + def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]: with get_db_context(db) as db: - return [ - message.user_id - for message in db.query(Message).filter_by(parent_id=id).all() - ] + return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()] def get_messages_by_channel_id( self, @@ -318,40 +294,34 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) @@ -387,55 +357,42 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) return messages - def get_last_message_by_channel_id( - self, channel_id: str, db: Optional[Session] = None - ) -> Optional[MessageModel]: + def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]: with get_db_context(db) as db: - message = ( - db.query(Message) - .filter_by(channel_id=channel_id) - .order_by(Message.created_at.desc()) - .first() - ) + message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first() return MessageModel.model_validate(message) if message else None def get_pinned_messages_by_channel_id( @@ -513,11 +470,7 @@ class MessageTable: ) -> Optional[MessageReactionModel]: with get_db_context(db) as db: # check for existing reaction - existing_reaction = ( - db.query(MessageReaction) - .filter_by(message_id=id, user_id=user_id, name=name) - .first() - ) + existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first() if existing_reaction: return MessageReactionModel.model_validate(existing_reaction) @@ -535,9 +488,7 @@ class MessageTable: db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[Reactions]: + def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]: with get_db_context(db) as db: # JOIN User so all user info is fetched in one query results = ( @@ -552,18 +503,18 @@ class MessageTable: for reaction, user in results: if reaction.name not in reactions: reactions[reaction.name] = { - "name": reaction.name, - "users": [], - "count": 0, + 'name': reaction.name, + 'users': [], + 'count': 0, } - reactions[reaction.name]["users"].append( + reactions[reaction.name]['users'].append( { - "id": user.id, - "name": user.name, + 'id': user.id, + 'name': user.name, } ) - reactions[reaction.name]["count"] += 1 + reactions[reaction.name]['count'] += 1 return [Reactions(**reaction) for reaction in reactions.values()] @@ -571,9 +522,7 @@ class MessageTable: self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> bool: with get_db_context(db) as db: - db.query(MessageReaction).filter_by( - message_id=id, user_id=user_id, name=name - ).delete() + db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete() db.commit() return True @@ -612,21 +561,15 @@ class MessageTable: with get_db_context(db) as db: query_builder = db.query(Message).filter( Message.channel_id.in_(channel_ids), - Message.content.ilike(f"%{query}%"), + Message.content.ilike(f'%{query}%'), ) if start_timestamp: - query_builder = query_builder.filter( - Message.created_at >= start_timestamp - ) + query_builder = query_builder.filter(Message.created_at >= start_timestamp) if end_timestamp: - query_builder = query_builder.filter( - Message.created_at <= end_timestamp - ) + query_builder = query_builder.filter(Message.created_at <= end_timestamp) - messages = ( - query_builder.order_by(Message.created_at.desc()).limit(limit).all() - ) + messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all() return [MessageModel.model_validate(msg) for msg in messages] diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 3bffe4ddcf..c48847b702 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -28,13 +28,13 @@ log = logging.getLogger(__name__) # ModelParams is a model for the data stored in the params field of the Model table class ModelParams(BaseModel): - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') pass # ModelMeta is a model for the data stored in the meta field of the Model table class ModelMeta(BaseModel): - profile_image_url: Optional[str] = "/static/favicon.png" + profile_image_url: Optional[str] = '/static/favicon.png' description: Optional[str] = None """ @@ -43,13 +43,13 @@ class ModelMeta(BaseModel): capabilities: Optional[dict] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') pass class Model(Base): - __tablename__ = "model" + __tablename__ = 'model' id = Column(Text, primary_key=True, unique=True) """ @@ -139,10 +139,8 @@ class ModelForm(BaseModel): class ModelsTable: - def _get_access_grants( - self, model_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("model", model_id, db=db) + def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('model', model_id, db=db) def _to_model_model( self, @@ -150,13 +148,9 @@ class ModelsTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> ModelModel: - model_data = ModelModel.model_validate(model).model_dump( - exclude={"access_grants"} - ) - model_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(model_data["id"], db=db) + model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'}) + model_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db) ) return ModelModel.model_validate(model_data) @@ -167,37 +161,32 @@ class ModelsTable: with get_db_context(db) as db: result = Model( **{ - **form_data.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "model", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db) if result: return self._to_model_model(result, db=db) else: return None except Exception as e: - log.exception(f"Failed to insert a new model: {e}") + log.exception(f'Failed to insert a new model: {e}') return None def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: all_models = db.query(Model).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - self._to_model_model( - model, access_grants=grants_map.get(model.id, []), db=db - ) - for model in all_models + self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models ] def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: @@ -209,7 +198,7 @@ class ModelsTable: users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) models = [] for model in all_models: @@ -222,7 +211,7 @@ class ModelsTable: access_grants=grants_map.get(model.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -232,28 +221,23 @@ class ModelsTable: with get_db_context(db) as db: all_models = db.query(Model).filter(Model.base_model_id == None).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - self._to_model_model( - model, access_grants=grants_map.get(model.id, []), db=db - ) - for model in all_models + self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models ] def get_models_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[ModelUserResponse]: models = self.get_models(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ model for model in models if model.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="model", + resource_type='model', resource_id=model.id, permission=permission, user_group_ids=user_group_ids, @@ -261,13 +245,13 @@ class ModelsTable: ) ] - def _has_permission(self, db, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = 'read'): return AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Model, filter=filter, - resource_type="model", + resource_type='model', permission=permission, ) @@ -285,22 +269,22 @@ class ModelsTable: query = query.filter(Model.base_model_id != None) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Model.name.ilike(f"%{query_key}%"), - Model.base_model_id.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), - User.username.ilike(f"%{query_key}%"), + Model.name.ilike(f'%{query_key}%'), + Model.base_model_id.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), + User.username.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Model.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Model.user_id != user_id) # Apply access control filtering @@ -308,10 +292,10 @@ class ModelsTable: db, query, filter, - permission="read", + permission='read', ) - tag = filter.get("tag") + tag = filter.get('tag') if tag: # TODO: This is a simple implementation and should be improved for performance like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array @@ -319,21 +303,21 @@ class ModelsTable: query = query.filter(meta_text.like(like_pattern)) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Model.name.asc()) else: query = query.order_by(Model.name.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Model.created_at.asc()) else: query = query.order_by(Model.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Model.updated_at.asc()) else: query = query.order_by(Model.updated_at.desc()) @@ -352,7 +336,7 @@ class ModelsTable: items = query.all() model_ids = [model.id for model, _ in items] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) models = [] for model, user in items: @@ -363,19 +347,13 @@ class ModelsTable: access_grants=grants_map.get(model.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) return ModelListResponse(items=models, total=total) - def get_model_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: try: with get_db_context(db) as db: model = db.get(Model, id) @@ -383,16 +361,12 @@ class ModelsTable: except Exception: return None - def get_models_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[ModelModel]: + def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]: try: with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() model_ids = [model.id for model in models] - grants_map = AccessGrants.get_grants_by_resources( - "model", model_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ self._to_model_model( model, @@ -404,9 +378,7 @@ class ModelsTable: except Exception: return [] - def toggle_model_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: with get_db_context(db) as db: try: model = db.query(Model).filter_by(id=id).first() @@ -422,30 +394,26 @@ class ModelsTable: except Exception: return None - def update_model_by_id( - self, id: str, model: ModelForm, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]: try: with get_db_context(db) as db: # update only the fields that are present in the model - data = model.model_dump(exclude={"id", "access_grants"}) + data = model.model_dump(exclude={'id', 'access_grants'}) result = db.query(Model).filter_by(id=id).update(data) db.commit() if model.access_grants is not None: - AccessGrants.set_access_grants( - "model", id, model.access_grants, db=db - ) + AccessGrants.set_access_grants('model', id, model.access_grants, db=db) return self.get_model_by_id(id, db=db) except Exception as e: - log.exception(f"Failed to update the model by id {id}: {e}") + log.exception(f'Failed to update the model by id {id}: {e}') return None def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("model", id, db=db) + AccessGrants.revoke_all_access('model', id, db=db) db.query(Model).filter_by(id=id).delete() db.commit() @@ -458,7 +426,7 @@ class ModelsTable: with get_db_context(db) as db: model_ids = [row[0] for row in db.query(Model.id).all()] for model_id in model_ids: - AccessGrants.revoke_all_access("model", model_id, db=db) + AccessGrants.revoke_all_access('model', model_id, db=db) db.query(Model).delete() db.commit() @@ -466,9 +434,7 @@ class ModelsTable: except Exception: return False - def sync_models( - self, user_id: str, models: list[ModelModel], db: Optional[Session] = None - ) -> list[ModelModel]: + def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]: try: with get_db_context(db) as db: # Get existing models @@ -483,37 +449,33 @@ class ModelsTable: if model.id in existing_ids: db.query(Model).filter_by(id=model.id).update( { - **model.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), + **model.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) else: new_model = Model( **{ - **model.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), + **model.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) db.add(new_model) - AccessGrants.set_access_grants( - "model", model.id, model.access_grants, db=db - ) + AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db) # Remove models that are no longer present for model in existing_models: if model.id not in new_model_ids: - AccessGrants.revoke_all_access("model", model.id, db=db) + AccessGrants.revoke_all_access('model', model.id, db=db) db.delete(model) db.commit() all_models = db.query(Model).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources( - "model", model_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ self._to_model_model( model, @@ -523,7 +485,7 @@ class ModelsTable: for model in all_models ] except Exception as e: - log.exception(f"Error syncing models for user {user_id}: {e}") + log.exception(f'Error syncing models for user {user_id}: {e}') return [] diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index ff8a3ac635..34749f5f6c 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -21,7 +21,7 @@ from sqlalchemy import or_, func, cast class Note(Base): - __tablename__ = "note" + __tablename__ = 'note' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) @@ -88,10 +88,8 @@ class NoteListResponse(BaseModel): class NoteTable: - def _get_access_grants( - self, note_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("note", note_id, db=db) + def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('note', note_id, db=db) def _to_note_model( self, @@ -99,51 +97,43 @@ class NoteTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> NoteModel: - note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"}) - note_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(note_data["id"], db=db) + note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'}) + note_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db) ) return NoteModel.model_validate(note_data) - def _has_permission(self, db, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = 'read'): return AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Note, filter=filter, - resource_type="note", + resource_type='note', permission=permission, ) - def insert_new_note( - self, user_id: str, form_data: NoteForm, db: Optional[Session] = None - ) -> Optional[NoteModel]: + def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]: with get_db_context(db) as db: note = NoteModel( **{ - "id": str(uuid.uuid4()), - "user_id": user_id, - **form_data.model_dump(exclude={"access_grants"}), - "created_at": int(time.time_ns()), - "updated_at": int(time.time_ns()), - "access_grants": [], + 'id': str(uuid.uuid4()), + 'user_id': user_id, + **form_data.model_dump(exclude={'access_grants'}), + 'created_at': int(time.time_ns()), + 'updated_at': int(time.time_ns()), + 'access_grants': [], } ) - new_note = Note(**note.model_dump(exclude={"access_grants"})) + new_note = Note(**note.model_dump(exclude={'access_grants'})) db.add(new_note) db.commit() - AccessGrants.set_access_grants( - "note", note.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db) return self._to_note_model(new_note, db=db) - def get_notes( - self, skip: int = 0, limit: int = 50, db: Optional[Session] = None - ) -> list[NoteModel]: + def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]: with get_db_context(db) as db: query = db.query(Note).order_by(Note.updated_at.desc()) if skip is not None: @@ -152,13 +142,8 @@ class NoteTable: query = query.limit(limit) notes = query.all() note_ids = [note.id for note in notes] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) - return [ - self._to_note_model( - note, access_grants=grants_map.get(note.id, []), db=db - ) - for note in notes - ] + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) + return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes] def search_notes( self, @@ -171,36 +156,32 @@ class NoteTable: with get_db_context(db) as db: query = db.query(Note, User).outerjoin(User, User.id == Note.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: # Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do") - normalized_query = query_key.replace("-", "").replace(" ", "") + normalized_query = query_key.replace('-', '').replace(' ', '') query = query.filter( or_( + func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'), func.replace( - func.replace(Note.title, "-", ""), " ", "" - ).ilike(f"%{normalized_query}%"), - func.replace( - func.replace( - cast(Note.data["content"]["md"], Text), "-", "" - ), - " ", - "", - ).ilike(f"%{normalized_query}%"), + func.replace(cast(Note.data['content']['md'], Text), '-', ''), + ' ', + '', + ).ilike(f'%{normalized_query}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Note.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Note.user_id != user_id) # Apply access control filtering - if "permission" in filter: - permission = filter["permission"] + if 'permission' in filter: + permission = filter['permission'] else: - permission = "write" + permission = 'write' query = self._has_permission( db, @@ -209,21 +190,21 @@ class NoteTable: permission=permission, ) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Note.title.asc()) else: query = query.order_by(Note.title.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Note.created_at.asc()) else: query = query.order_by(Note.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Note.updated_at.asc()) else: query = query.order_by(Note.updated_at.desc()) @@ -244,7 +225,7 @@ class NoteTable: items = query.all() note_ids = [note.id for note, _ in items] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) notes = [] for note, user in items: @@ -255,11 +236,7 @@ class NoteTable: access_grants=grants_map.get(note.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -268,20 +245,16 @@ class NoteTable: def get_notes_by_user_id( self, user_id: str, - permission: str = "read", + permission: str = 'read', skip: int = 0, limit: int = 50, db: Optional[Session] = None, ) -> list[NoteModel]: with get_db_context(db) as db: - user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - ] + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)] query = db.query(Note).order_by(Note.updated_at.desc()) - query = self._has_permission( - db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission - ) + query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission) if skip is not None: query = query.offset(skip) @@ -290,17 +263,10 @@ class NoteTable: notes = query.all() note_ids = [note.id for note in notes] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) - return [ - self._to_note_model( - note, access_grants=grants_map.get(note.id, []), db=db - ) - for note in notes - ] + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) + return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes] - def get_note_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[NoteModel]: + def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]: with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() return self._to_note_model(note, db=db) if note else None @@ -315,17 +281,15 @@ class NoteTable: form_data = form_data.model_dump(exclude_unset=True) - if "title" in form_data: - note.title = form_data["title"] - if "data" in form_data: - note.data = {**note.data, **form_data["data"]} - if "meta" in form_data: - note.meta = {**note.meta, **form_data["meta"]} + if 'title' in form_data: + note.title = form_data['title'] + if 'data' in form_data: + note.data = {**note.data, **form_data['data']} + if 'meta' in form_data: + note.meta = {**note.meta, **form_data['meta']} - if "access_grants" in form_data: - AccessGrants.set_access_grants( - "note", id, form_data["access_grants"], db=db - ) + if 'access_grants' in form_data: + AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db) note.updated_at = int(time.time_ns()) @@ -335,7 +299,7 @@ class NoteTable: def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("note", id, db=db) + AccessGrants.revoke_all_access('note', id, db=db) db.query(Note).filter(Note.id == id).delete() db.commit() return True diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index c9110d3267..868216164a 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -23,23 +23,21 @@ log = logging.getLogger(__name__) class OAuthSession(Base): - __tablename__ = "oauth_session" + __tablename__ = 'oauth_session' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text, nullable=False) provider = Column(Text, nullable=False) - token = Column( - Text, nullable=False - ) # JSON with access_token, id_token, refresh_token + token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token expires_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) # Add indexes for better performance __table_args__ = ( - Index("idx_oauth_session_user_id", "user_id"), - Index("idx_oauth_session_expires_at", "expires_at"), - Index("idx_oauth_session_user_provider", "user_id", "provider"), + Index('idx_oauth_session_user_id', 'user_id'), + Index('idx_oauth_session_expires_at', 'expires_at'), + Index('idx_oauth_session_user_provider', 'user_id', 'provider'), ) @@ -71,7 +69,7 @@ class OAuthSessionTable: def __init__(self): self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY if not self.encryption_key: - raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set") + raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set') # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes) if len(self.encryption_key) != 44: @@ -83,7 +81,7 @@ class OAuthSessionTable: try: self.fernet = Fernet(self.encryption_key) except Exception as e: - log.error(f"Error initializing Fernet with provided key: {e}") + log.error(f'Error initializing Fernet with provided key: {e}') raise def _encrypt_token(self, token) -> str: @@ -93,7 +91,7 @@ class OAuthSessionTable: encrypted = self.fernet.encrypt(token_json.encode()).decode() return encrypted except Exception as e: - log.error(f"Error encrypting tokens: {e}") + log.error(f'Error encrypting tokens: {e}') raise def _decrypt_token(self, token: str): @@ -102,7 +100,7 @@ class OAuthSessionTable: decrypted = self.fernet.decrypt(token.encode()).decode() return json.loads(decrypted) except Exception as e: - log.error(f"Error decrypting tokens: {type(e).__name__}: {e}") + log.error(f'Error decrypting tokens: {type(e).__name__}: {e}') raise def create_session( @@ -120,13 +118,13 @@ class OAuthSessionTable: result = OAuthSession( **{ - "id": id, - "user_id": user_id, - "provider": provider, - "token": self._encrypt_token(token), - "expires_at": token.get("expires_at"), - "created_at": current_time, - "updated_at": current_time, + 'id': id, + 'user_id': user_id, + 'provider': provider, + 'token': self._encrypt_token(token), + 'expires_at': token.get('expires_at'), + 'created_at': current_time, + 'updated_at': current_time, } ) @@ -141,12 +139,10 @@ class OAuthSessionTable: else: return None except Exception as e: - log.error(f"Error creating OAuth session: {e}") + log.error(f'Error creating OAuth session: {e}') return None - def get_session_by_id( - self, session_id: str, db: Optional[Session] = None - ) -> Optional[OAuthSessionModel]: + def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]: """Get OAuth session by ID""" try: with get_db_context(db) as db: @@ -158,7 +154,7 @@ class OAuthSessionTable: return None except Exception as e: - log.error(f"Error getting OAuth session by ID: {e}") + log.error(f'Error getting OAuth session by ID: {e}') return None def get_session_by_id_and_user_id( @@ -167,11 +163,7 @@ class OAuthSessionTable: """Get OAuth session by ID and user ID""" try: with get_db_context(db) as db: - session = ( - db.query(OAuthSession) - .filter_by(id=session_id, user_id=user_id) - .first() - ) + session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first() if session: db.expunge(session) session.token = self._decrypt_token(session.token) @@ -179,7 +171,7 @@ class OAuthSessionTable: return None except Exception as e: - log.error(f"Error getting OAuth session by ID: {e}") + log.error(f'Error getting OAuth session by ID: {e}') return None def get_session_by_provider_and_user_id( @@ -201,12 +193,10 @@ class OAuthSessionTable: return None except Exception as e: - log.error(f"Error getting OAuth session by provider and user ID: {e}") + log.error(f'Error getting OAuth session by provider and user ID: {e}') return None - def get_sessions_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> List[OAuthSessionModel]: + def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]: """Get all OAuth sessions for a user""" try: with get_db_context(db) as db: @@ -220,7 +210,7 @@ class OAuthSessionTable: results.append(OAuthSessionModel.model_validate(session)) except Exception as e: log.warning( - f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}" + f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}' ) db.query(OAuthSession).filter_by(id=session.id).delete() db.commit() @@ -228,7 +218,7 @@ class OAuthSessionTable: return results except Exception as e: - log.error(f"Error getting OAuth sessions by user ID: {e}") + log.error(f'Error getting OAuth sessions by user ID: {e}') return [] def update_session_by_id( @@ -241,9 +231,9 @@ class OAuthSessionTable: db.query(OAuthSession).filter_by(id=session_id).update( { - "token": self._encrypt_token(token), - "expires_at": token.get("expires_at"), - "updated_at": current_time, + 'token': self._encrypt_token(token), + 'expires_at': token.get('expires_at'), + 'updated_at': current_time, } ) db.commit() @@ -256,12 +246,10 @@ class OAuthSessionTable: return None except Exception as e: - log.error(f"Error updating OAuth session tokens: {e}") + log.error(f'Error updating OAuth session tokens: {e}') return None - def delete_session_by_id( - self, session_id: str, db: Optional[Session] = None - ) -> bool: + def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool: """Delete an OAuth session""" try: with get_db_context(db) as db: @@ -269,12 +257,10 @@ class OAuthSessionTable: db.commit() return result > 0 except Exception as e: - log.error(f"Error deleting OAuth session: {e}") + log.error(f'Error deleting OAuth session: {e}') return False - def delete_sessions_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: """Delete all OAuth sessions for a user""" try: with get_db_context(db) as db: @@ -282,12 +268,10 @@ class OAuthSessionTable: db.commit() return True except Exception as e: - log.error(f"Error deleting OAuth sessions by user ID: {e}") + log.error(f'Error deleting OAuth sessions by user ID: {e}') return False - def delete_sessions_by_provider( - self, provider: str, db: Optional[Session] = None - ) -> bool: + def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool: """Delete all OAuth sessions for a provider""" try: with get_db_context(db) as db: @@ -295,7 +279,7 @@ class OAuthSessionTable: db.commit() return True except Exception as e: - log.error(f"Error deleting OAuth sessions by provider {provider}: {e}") + log.error(f'Error deleting OAuth sessions by provider {provider}: {e}') return False diff --git a/backend/open_webui/models/prompt_history.py b/backend/open_webui/models/prompt_history.py index 91ca4cb445..d42b4bfa24 100644 --- a/backend/open_webui/models/prompt_history.py +++ b/backend/open_webui/models/prompt_history.py @@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Column, Text, JSON, Index class PromptHistory(Base): - __tablename__ = "prompt_history" + __tablename__ = 'prompt_history' id = Column(Text, primary_key=True) prompt_id = Column(Text, nullable=False, index=True) @@ -100,11 +100,7 @@ class PromptHistoryTable: return [ PromptHistoryResponse( **PromptHistoryModel.model_validate(entry).model_dump(), - user=( - users_dict.get(entry.user_id).model_dump() - if users_dict.get(entry.user_id) - else None - ), + user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None), ) for entry in entries ] @@ -116,9 +112,7 @@ class PromptHistoryTable: ) -> Optional[PromptHistoryModel]: """Get a specific history entry by ID.""" with get_db_context(db) as db: - entry = ( - db.query(PromptHistory).filter(PromptHistory.id == history_id).first() - ) + entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first() if entry: return PromptHistoryModel.model_validate(entry) return None @@ -147,11 +141,7 @@ class PromptHistoryTable: ) -> int: """Get the number of history entries for a prompt.""" with get_db_context(db) as db: - return ( - db.query(PromptHistory) - .filter(PromptHistory.prompt_id == prompt_id) - .count() - ) + return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count() def compute_diff( self, @@ -161,9 +151,7 @@ class PromptHistoryTable: ) -> Optional[dict]: """Compute diff between two history entries.""" with get_db_context(db) as db: - from_entry = ( - db.query(PromptHistory).filter(PromptHistory.id == from_id).first() - ) + from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first() to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first() if not from_entry or not to_entry: @@ -173,26 +161,26 @@ class PromptHistoryTable: to_snapshot = to_entry.snapshot # Compute diff for content field - from_content = from_snapshot.get("content", "") - to_content = to_snapshot.get("content", "") + from_content = from_snapshot.get('content', '') + to_content = to_snapshot.get('content', '') diff_lines = list( difflib.unified_diff( from_content.splitlines(keepends=True), to_content.splitlines(keepends=True), - fromfile=f"v{from_id[:8]}", - tofile=f"v{to_id[:8]}", - lineterm="", + fromfile=f'v{from_id[:8]}', + tofile=f'v{to_id[:8]}', + lineterm='', ) ) return { - "from_id": from_id, - "to_id": to_id, - "from_snapshot": from_snapshot, - "to_snapshot": to_snapshot, - "content_diff": diff_lines, - "name_changed": from_snapshot.get("name") != to_snapshot.get("name"), + 'from_id': from_id, + 'to_id': to_id, + 'from_snapshot': from_snapshot, + 'to_snapshot': to_snapshot, + 'content_diff': diff_lines, + 'name_changed': from_snapshot.get('name') != to_snapshot.get('name'), } def delete_history_by_prompt_id( @@ -202,9 +190,7 @@ class PromptHistoryTable: ) -> bool: """Delete all history entries for a prompt.""" with get_db_context(db) as db: - db.query(PromptHistory).filter( - PromptHistory.prompt_id == prompt_id - ).delete() + db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete() db.commit() return True diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index e32621f4e5..028b7a1bc7 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, fun class Prompt(Base): - __tablename__ = "prompt" + __tablename__ = 'prompt' id = Column(Text, primary_key=True) command = Column(String, unique=True, index=True) @@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel): class PromptForm(BaseModel): - command: str name: str # Changed from title content: str @@ -91,10 +90,8 @@ class PromptForm(BaseModel): class PromptsTable: - def _get_access_grants( - self, prompt_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db) + def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db) def _to_prompt_model( self, @@ -102,13 +99,9 @@ class PromptsTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> PromptModel: - prompt_data = PromptModel.model_validate(prompt).model_dump( - exclude={"access_grants"} - ) - prompt_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(prompt_data["id"], db=db) + prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'}) + prompt_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db) ) return PromptModel.model_validate(prompt_data) @@ -135,26 +128,22 @@ class PromptsTable: try: with get_db_context(db) as db: - result = Prompt(**prompt.model_dump(exclude={"access_grants"})) + result = Prompt(**prompt.model_dump(exclude={'access_grants'})) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "prompt", prompt_id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db) if result: current_access_grants = self._get_access_grants(prompt_id, db=db) snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": form_data.command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "tags": form_data.tags or [], - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': form_data.name, + 'content': form_data.content, + 'command': form_data.command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'tags': form_data.tags or [], + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -162,7 +151,7 @@ class PromptsTable: snapshot=snapshot, user_id=user_id, parent_id=None, # Initial commit has no parent - commit_message=form_data.commit_message or "Initial version", + commit_message=form_data.commit_message or 'Initial version', db=db, ) @@ -178,9 +167,7 @@ class PromptsTable: except Exception: return None - def get_prompt_by_id( - self, prompt_id: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]: """Get prompt by UUID.""" try: with get_db_context(db) as db: @@ -191,9 +178,7 @@ class PromptsTable: except Exception: return None - def get_prompt_by_command( - self, command: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]: try: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() @@ -205,21 +190,14 @@ class PromptsTable: def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: with get_db_context(db) as db: - all_prompts = ( - db.query(Prompt) - .filter(Prompt.is_active == True) - .order_by(Prompt.updated_at.desc()) - .all() - ) + all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all() user_ids = list(set(prompt.user_id for prompt in all_prompts)) prompt_ids = [prompt.id for prompt in all_prompts] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources( - "prompt", prompt_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) prompts = [] for prompt in all_prompts: @@ -232,7 +210,7 @@ class PromptsTable: access_grants=grants_map.get(prompt.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -240,12 +218,10 @@ class PromptsTable: return prompts def get_prompts_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[PromptUserResponse]: prompts = self.get_prompts(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ prompt @@ -253,7 +229,7 @@ class PromptsTable: if prompt.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, permission=permission, user_group_ids=user_group_ids, @@ -276,22 +252,22 @@ class PromptsTable: query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Prompt.name.ilike(f"%{query_key}%"), - Prompt.command.ilike(f"%{query_key}%"), - Prompt.content.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), + Prompt.name.ilike(f'%{query_key}%'), + Prompt.command.ilike(f'%{query_key}%'), + Prompt.content.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Prompt.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Prompt.user_id != user_id) # Apply access grant filtering @@ -300,32 +276,32 @@ class PromptsTable: query=query, DocumentModel=Prompt, filter=filter, - resource_type="prompt", - permission="read", + resource_type='prompt', + permission='read', ) - tag = filter.get("tag") + tag = filter.get('tag') if tag: # Search for tag in JSON array field like_pattern = f'%"{tag.lower()}"%' tags_text = func.lower(cast(Prompt.tags, String)) query = query.filter(tags_text.like(like_pattern)) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Prompt.name.asc()) else: query = query.order_by(Prompt.name.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Prompt.created_at.asc()) else: query = query.order_by(Prompt.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Prompt.updated_at.asc()) else: query = query.order_by(Prompt.updated_at.desc()) @@ -345,9 +321,7 @@ class PromptsTable: items = query.all() prompt_ids = [prompt.id for prompt, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "prompt", prompt_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) prompts = [] for prompt, user in items: @@ -358,11 +332,7 @@ class PromptsTable: access_grants=grants_map.get(prompt.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -381,9 +351,7 @@ class PromptsTable: if not prompt: return None - latest_history = PromptHistories.get_latest_history_entry( - prompt.id, db=db - ) + latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db) parent_id = latest_history.id if latest_history else None current_access_grants = self._get_access_grants(prompt.id, db=db) @@ -401,9 +369,7 @@ class PromptsTable: prompt.meta = form_data.meta or prompt.meta prompt.updated_at = int(time.time()) if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "prompt", prompt.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db) db.commit() @@ -411,14 +377,12 @@ class PromptsTable: # Create history entry only if content changed if content_changed: snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': form_data.name, + 'content': form_data.content, + 'command': command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -452,9 +416,7 @@ class PromptsTable: if not prompt: return None - latest_history = PromptHistories.get_latest_history_entry( - prompt.id, db=db - ) + latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db) parent_id = latest_history.id if latest_history else None current_access_grants = self._get_access_grants(prompt.id, db=db) @@ -478,9 +440,7 @@ class PromptsTable: prompt.tags = form_data.tags if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "prompt", prompt.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db) prompt.updated_at = int(time.time()) @@ -490,15 +450,13 @@ class PromptsTable: # Create history entry only if content changed if content_changed: snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": prompt.command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "tags": prompt.tags or [], - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': form_data.name, + 'content': form_data.content, + 'command': prompt.command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'tags': prompt.tags or [], + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -560,9 +518,7 @@ class PromptsTable: if not prompt: return None - history_entry = PromptHistories.get_history_entry_by_id( - version_id, db=db - ) + history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db) if not history_entry: return None @@ -570,11 +526,11 @@ class PromptsTable: # Restore prompt content from the snapshot snapshot = history_entry.snapshot if snapshot: - prompt.name = snapshot.get("name", prompt.name) - prompt.content = snapshot.get("content", prompt.content) - prompt.data = snapshot.get("data", prompt.data) - prompt.meta = snapshot.get("meta", prompt.meta) - prompt.tags = snapshot.get("tags", prompt.tags) + prompt.name = snapshot.get('name', prompt.name) + prompt.content = snapshot.get('content', prompt.content) + prompt.data = snapshot.get('data', prompt.data) + prompt.meta = snapshot.get('meta', prompt.meta) + prompt.tags = snapshot.get('tags', prompt.tags) # Note: command and access_grants are not restored from snapshot prompt.version_id = version_id @@ -585,9 +541,7 @@ class PromptsTable: except Exception: return None - def toggle_prompt_active( - self, prompt_id: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]: """Toggle the is_active flag on a prompt.""" try: with get_db_context(db) as db: @@ -602,16 +556,14 @@ class PromptsTable: except Exception: return None - def delete_prompt_by_command( - self, command: str, db: Optional[Session] = None - ) -> bool: + def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool: """Permanently delete a prompt and its history.""" try: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + AccessGrants.revoke_all_access('prompt', prompt.id, db=db) db.delete(prompt) db.commit() @@ -627,7 +579,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(id=prompt_id).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + AccessGrants.revoke_all_access('prompt', prompt.id, db=db) db.delete(prompt) db.commit() diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py index 6bd5affce8..cdf8ecaea4 100644 --- a/backend/open_webui/models/skills.py +++ b/backend/open_webui/models/skills.py @@ -19,7 +19,7 @@ log = logging.getLogger(__name__) class Skill(Base): - __tablename__ = "skill" + __tablename__ = 'skill' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -77,7 +77,7 @@ class SkillResponse(BaseModel): class SkillUserResponse(SkillResponse): user: Optional[UserResponse] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class SkillAccessResponse(SkillUserResponse): @@ -105,10 +105,8 @@ class SkillAccessListResponse(BaseModel): class SkillsTable: - def _get_access_grants( - self, skill_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("skill", skill_id, db=db) + def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('skill', skill_id, db=db) def _to_skill_model( self, @@ -116,13 +114,9 @@ class SkillsTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> SkillModel: - skill_data = SkillModel.model_validate(skill).model_dump( - exclude={"access_grants"} - ) - skill_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(skill_data["id"], db=db) + skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'}) + skill_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db) ) return SkillModel.model_validate(skill_data) @@ -136,29 +130,25 @@ class SkillsTable: try: result = Skill( **{ - **form_data.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "skill", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db) if result: return self._to_skill_model(result, db=db) else: return None except Exception as e: - log.exception(f"Error creating a new skill: {e}") + log.exception(f'Error creating a new skill: {e}') return None - def get_skill_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: skill = db.get(Skill, id) @@ -166,9 +156,7 @@ class SkillsTable: except Exception: return None - def get_skill_by_name( - self, name: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: skill = db.query(Skill).filter_by(name=name).first() @@ -185,7 +173,7 @@ class SkillsTable: users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db) skills = [] for skill in all_skills: @@ -198,19 +186,17 @@ class SkillsTable: access_grants=grants_map.get(skill.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) return skills def get_skills_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[SkillUserModel]: skills = self.get_skills(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ skill @@ -218,7 +204,7 @@ class SkillsTable: if skill.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, permission=permission, user_group_ids=user_group_ids, @@ -242,22 +228,22 @@ class SkillsTable: query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Skill.name.ilike(f"%{query_key}%"), - Skill.description.ilike(f"%{query_key}%"), - Skill.id.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), + Skill.name.ilike(f'%{query_key}%'), + Skill.description.ilike(f'%{query_key}%'), + Skill.id.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Skill.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Skill.user_id != user_id) # Apply access grant filtering @@ -266,8 +252,8 @@ class SkillsTable: query=query, DocumentModel=Skill, filter=filter, - resource_type="skill", - permission="read", + resource_type='skill', + permission='read', ) query = query.order_by(Skill.updated_at.desc()) @@ -283,9 +269,7 @@ class SkillsTable: items = query.all() skill_ids = [skill.id for skill, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "skill", skill_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db) skills = [] for skill, user in items: @@ -296,33 +280,23 @@ class SkillsTable: access_grants=grants_map.get(skill.id, []), db=db, ).model_dump(), - user=( - UserResponse( - **UserModel.model_validate(user).model_dump() - ) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) return SkillListResponse(items=skills, total=total) except Exception as e: - log.exception(f"Error searching skills: {e}") + log.exception(f'Error searching skills: {e}') return SkillListResponse(items=[], total=0) - def update_skill_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: - access_grants = updated.pop("access_grants", None) - db.query(Skill).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) + access_grants = updated.pop('access_grants', None) + db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())}) db.commit() if access_grants is not None: - AccessGrants.set_access_grants("skill", id, access_grants, db=db) + AccessGrants.set_access_grants('skill', id, access_grants, db=db) skill = db.query(Skill).get(id) db.refresh(skill) @@ -330,9 +304,7 @@ class SkillsTable: except Exception: return None - def toggle_skill_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]: with get_db_context(db) as db: try: skill = db.query(Skill).filter_by(id=id).first() @@ -351,7 +323,7 @@ class SkillsTable: def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("skill", id, db=db) + AccessGrants.revoke_all_access('skill', id, db=db) db.query(Skill).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 147bb394d5..8e401f3010 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -17,19 +17,19 @@ log = logging.getLogger(__name__) # Tag DB Schema #################### class Tag(Base): - __tablename__ = "tag" + __tablename__ = 'tag' id = Column(String) name = Column(String) user_id = Column(String) meta = Column(JSON, nullable=True) __table_args__ = ( - PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"), - Index("user_id_idx", "user_id"), + PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'), + Index('user_id_idx', 'user_id'), ) # Unique constraint ensuring (id, user_id) is unique, not just the `id` column - __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) + __table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),) class TagModel(BaseModel): @@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel): class TagTable: - def insert_new_tag( - self, name: str, user_id: str, db: Optional[Session] = None - ) -> Optional[TagModel]: + def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]: with get_db_context(db) as db: - id = name.replace(" ", "_").lower() - tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + id = name.replace(' ', '_').lower() + tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name}) try: result = Tag(**tag.model_dump()) db.add(result) @@ -67,89 +65,63 @@ class TagTable: else: return None except Exception as e: - log.exception(f"Error inserting a new tag: {e}") + log.exception(f'Error inserting a new tag: {e}') return None - def get_tag_by_name_and_user_id( - self, name: str, user_id: str, db: Optional[Session] = None - ) -> Optional[TagModel]: + def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]: try: - id = name.replace(" ", "_").lower() + id = name.replace(' ', '_').lower() with get_db_context(db) as db: tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def get_tags_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[TagModel]: + def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]: + with get_db_context(db) as db: + return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())] + + def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]: with get_db_context(db) as db: return [ TagModel.model_validate(tag) - for tag in (db.query(Tag).filter_by(user_id=user_id).all()) + for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()) ] - def get_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str, db: Optional[Session] = None - ) -> list[TagModel]: - with get_db_context(db) as db: - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all() - ) - ] - - def delete_tag_by_name_and_user_id( - self, name: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - id = name.replace(" ", "_").lower() + id = name.replace(' ', '_').lower() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() - log.debug(f"res: {res}") + log.debug(f'res: {res}') db.commit() return True except Exception as e: - log.error(f"delete_tag: {e}") + log.error(f'delete_tag: {e}') return False - def delete_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool: """Delete all tags whose id is in *ids* for the given user, in one query.""" if not ids: return True try: with get_db_context(db) as db: - db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete( - synchronize_session=False - ) + db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False) db.commit() return True except Exception as e: - log.error(f"delete_tags_by_ids: {e}") + log.error(f'delete_tags_by_ids: {e}') return False - def ensure_tags_exist( - self, names: list[str], user_id: str, db: Optional[Session] = None - ) -> None: + def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None: """Create tag rows for any *names* that don't already exist for *user_id*.""" if not names: return - ids = [n.replace(" ", "_").lower() for n in names] + ids = [n.replace(' ', '_').lower() for n in names] with get_db_context(db) as db: - existing = { - t.id - for t in db.query(Tag.id) - .filter(Tag.id.in_(ids), Tag.user_id == user_id) - .all() - } + existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()} new_tags = [ - Tag(id=tag_id, name=name, user_id=user_id) - for tag_id, name in zip(ids, names) - if tag_id not in existing + Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing ] if new_tags: db.add_all(new_tags) diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index f813ce21cd..02dacaa80c 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -19,7 +19,7 @@ log = logging.getLogger(__name__) class Tool(Base): - __tablename__ = "tool" + __tablename__ = 'tool' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -75,7 +75,7 @@ class ToolResponse(BaseModel): class ToolUserResponse(ToolResponse): user: Optional[UserResponse] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class ToolAccessResponse(ToolUserResponse): @@ -95,10 +95,8 @@ class ToolValves(BaseModel): class ToolsTable: - def _get_access_grants( - self, tool_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("tool", tool_id, db=db) + def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('tool', tool_id, db=db) def _to_tool_model( self, @@ -106,11 +104,9 @@ class ToolsTable: access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> ToolModel: - tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"}) - tool_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(tool_data["id"], db=db) + tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'}) + tool_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db) ) return ToolModel.model_validate(tool_data) @@ -125,30 +121,26 @@ class ToolsTable: try: result = Tool( **{ - **form_data.model_dump(exclude={"access_grants"}), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'specs': specs, + 'user_id': user_id, + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "tool", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db) if result: return self._to_tool_model(result, db=db) else: return None except Exception as e: - log.exception(f"Error creating a new tool: {e}") + log.exception(f'Error creating a new tool: {e}') return None - def get_tool_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ToolModel]: + def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]: try: with get_db_context(db) as db: tool = db.get(Tool, id) @@ -156,9 +148,7 @@ class ToolsTable: except Exception: return None - def get_tools( - self, defer_content: bool = False, db: Optional[Session] = None - ) -> list[ToolUserModel]: + def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]: with get_db_context(db) as db: query = db.query(Tool).order_by(Tool.updated_at.desc()) if defer_content: @@ -170,7 +160,7 @@ class ToolsTable: users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db) tools = [] for tool in all_tools: @@ -183,7 +173,7 @@ class ToolsTable: access_grants=grants_map.get(tool.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -192,14 +182,12 @@ class ToolsTable: def get_tools_by_user_id( self, user_id: str, - permission: str = "write", + permission: str = 'write', defer_content: bool = False, db: Optional[Session] = None, ) -> list[ToolUserModel]: tools = self.get_tools(defer_content=defer_content, db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ tool @@ -207,7 +195,7 @@ class ToolsTable: if tool.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="tool", + resource_type='tool', resource_id=tool.id, permission=permission, user_group_ids=user_group_ids, @@ -215,48 +203,38 @@ class ToolsTable: ) ] - def get_tool_valves_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]: try: with get_db_context(db) as db: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: - log.exception(f"Error getting tool valves by id {id}") + log.exception(f'Error getting tool valves by id {id}') return None - def update_tool_valves_by_id( - self, id: str, valves: dict, db: Optional[Session] = None - ) -> Optional[ToolValves]: + def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]: try: with get_db_context(db) as db: - db.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) + db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())}) db.commit() return self.get_tool_by_id(id, db=db) except Exception: return None - def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]: try: user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings - if "tools" not in user_settings: - user_settings["tools"] = {} - if "valves" not in user_settings["tools"]: - user_settings["tools"]["valves"] = {} + if 'tools' not in user_settings: + user_settings['tools'] = {} + if 'valves' not in user_settings['tools']: + user_settings['tools']['valves'] = {} - return user_settings["tools"]["valves"].get(id, {}) + return user_settings['tools']['valves'].get(id, {}) except Exception as e: - log.exception( - f"Error getting user values by id {id} and user_id {user_id}: {e}" - ) + log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}') return None def update_user_valves_by_id_and_user_id( @@ -267,35 +245,29 @@ class ToolsTable: user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings - if "tools" not in user_settings: - user_settings["tools"] = {} - if "valves" not in user_settings["tools"]: - user_settings["tools"]["valves"] = {} + if 'tools' not in user_settings: + user_settings['tools'] = {} + if 'valves' not in user_settings['tools']: + user_settings['tools']['valves'] = {} - user_settings["tools"]["valves"][id] = valves + user_settings['tools']['valves'][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) + Users.update_user_by_id(user_id, {'settings': user_settings}, db=db) - return user_settings["tools"]["valves"][id] + return user_settings['tools']['valves'][id] except Exception as e: - log.exception( - f"Error updating user valves by id {id} and user_id {user_id}: {e}" - ) + log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}') return None - def update_tool_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[ToolModel]: + def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]: try: with get_db_context(db) as db: - access_grants = updated.pop("access_grants", None) - db.query(Tool).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) + access_grants = updated.pop('access_grants', None) + db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())}) db.commit() if access_grants is not None: - AccessGrants.set_access_grants("tool", id, access_grants, db=db) + AccessGrants.set_access_grants('tool', id, access_grants, db=db) tool = db.query(Tool).get(id) db.refresh(tool) @@ -306,7 +278,7 @@ class ToolsTable: def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("tool", id, db=db) + AccessGrants.revoke_all_access('tool', id, db=db) db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index e5da9231df..9015646444 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -40,12 +40,12 @@ import datetime class UserSettings(BaseModel): ui: Optional[dict] = {} - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') pass class User(Base): - __tablename__ = "user" + __tablename__ = 'user' id = Column(String, primary_key=True, unique=True) email = Column(String) @@ -83,7 +83,7 @@ class UserModel(BaseModel): email: str username: Optional[str] = None - role: str = "pending" + role: str = 'pending' name: str @@ -112,10 +112,10 @@ class UserModel(BaseModel): model_config = ConfigDict(from_attributes=True) - @model_validator(mode="after") + @model_validator(mode='after') def set_profile_image_url(self): if not self.profile_image_url: - self.profile_image_url = f"/api/v1/users/{self.id}/profile/image" + self.profile_image_url = f'/api/v1/users/{self.id}/profile/image' return self @@ -126,7 +126,7 @@ class UserStatusModel(UserModel): class ApiKey(Base): - __tablename__ = "api_key" + __tablename__ = 'api_key' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text, nullable=False) @@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel): gender: Optional[str] = None date_of_birth: Optional[datetime.date] = None - @field_validator("profile_image_url") + @field_validator('profile_image_url') @classmethod def check_profile_image_url(cls, v: str) -> str: return validate_profile_image_url(v) @@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel): class UserModelResponse(UserModel): - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class UserListResponse(BaseModel): @@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel): profile_image_url: str password: Optional[str] = None - @field_validator("profile_image_url") + @field_validator('profile_image_url') @classmethod def check_profile_image_url(cls, v: str) -> str: return validate_profile_image_url(v) @@ -263,8 +263,8 @@ class UsersTable: id: str, name: str, email: str, - profile_image_url: str = "/user.png", - role: str = "pending", + profile_image_url: str = '/user.png', + role: str = 'pending', username: Optional[str] = None, oauth: Optional[dict] = None, db: Optional[Session] = None, @@ -272,16 +272,16 @@ class UsersTable: with get_db_context(db) as db: user = UserModel( **{ - "id": id, - "email": email, - "name": name, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "username": username, - "oauth": oauth, + 'id': id, + 'email': email, + 'name': name, + 'role': role, + 'profile_image_url': profile_image_url, + 'last_active_at': int(time.time()), + 'created_at': int(time.time()), + 'updated_at': int(time.time()), + 'username': username, + 'oauth': oauth, } ) result = User(**user.model_dump()) @@ -293,9 +293,7 @@ class UsersTable: else: return None - def get_user_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -303,49 +301,32 @@ class UsersTable: except Exception: return None - def get_user_by_api_key( - self, api_key: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: - user = ( - db.query(User) - .join(ApiKey, User.id == ApiKey.user_id) - .filter(ApiKey.key == api_key) - .first() - ) + user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first() return UserModel.model_validate(user) if user else None except Exception: return None - def get_user_by_email( - self, email: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: - user = ( - db.query(User) - .filter(func.lower(User.email) == email.lower()) - .first() - ) + user = db.query(User).filter(func.lower(User.email) == email.lower()).first() return UserModel.model_validate(user) if user else None except Exception: return None - def get_user_by_oauth_sub( - self, provider: str, sub: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: # type: Session dialect_name = db.bind.dialect.name query = db.query(User) - if dialect_name == "sqlite": - query = query.filter(User.oauth.contains({provider: {"sub": sub}})) - elif dialect_name == "postgresql": - query = query.filter( - User.oauth[provider].cast(JSONB)["sub"].astext == sub - ) + if dialect_name == 'sqlite': + query = query.filter(User.oauth.contains({provider: {'sub': sub}})) + elif dialect_name == 'postgresql': + query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub) user = query.first() return UserModel.model_validate(user) if user else None @@ -361,15 +342,10 @@ class UsersTable: dialect_name = db.bind.dialect.name query = db.query(User) - if dialect_name == "sqlite": - query = query.filter( - User.scim.contains({provider: {"external_id": external_id}}) - ) - elif dialect_name == "postgresql": - query = query.filter( - User.scim[provider].cast(JSONB)["external_id"].astext - == external_id - ) + if dialect_name == 'sqlite': + query = query.filter(User.scim.contains({provider: {'external_id': external_id}})) + elif dialect_name == 'postgresql': + query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id) user = query.first() return UserModel.model_validate(user) if user else None @@ -388,16 +364,16 @@ class UsersTable: query = db.query(User).options(defer(User.profile_image_url)) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), ) ) - channel_id = filter.get("channel_id") + channel_id = filter.get('channel_id') if channel_id: query = query.filter( exists( @@ -408,13 +384,13 @@ class UsersTable: ) ) - user_ids = filter.get("user_ids") - group_ids = filter.get("group_ids") + user_ids = filter.get('user_ids') + group_ids = filter.get('group_ids') if isinstance(user_ids, list) and isinstance(group_ids, list): # If both are empty lists, return no users if not user_ids and not group_ids: - return {"users": [], "total": 0} + return {'users': [], 'total': 0} if user_ids: query = query.filter(User.id.in_(user_ids)) @@ -429,21 +405,21 @@ class UsersTable: ) ) - roles = filter.get("roles") + roles = filter.get('roles') if roles: - include_roles = [role for role in roles if not role.startswith("!")] - exclude_roles = [role[1:] for role in roles if role.startswith("!")] + include_roles = [role for role in roles if not role.startswith('!')] + exclude_roles = [role[1:] for role in roles if role.startswith('!')] if include_roles: query = query.filter(User.role.in_(include_roles)) if exclude_roles: query = query.filter(~User.role.in_(exclude_roles)) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by and order_by.startswith("group_id:"): - group_id = order_by.split(":", 1)[1] + if order_by and order_by.startswith('group_id:'): + group_id = order_by.split(':', 1)[1] # Subquery that checks if the user belongs to the group membership_exists = exists( @@ -456,42 +432,42 @@ class UsersTable: # CASE: user in group → 1, user not in group → 0 group_sort = case((membership_exists, 1), else_=0) - if direction == "asc": + if direction == 'asc': query = query.order_by(group_sort.asc(), User.name.asc()) else: query = query.order_by(group_sort.desc(), User.name.asc()) - elif order_by == "name": - if direction == "asc": + elif order_by == 'name': + if direction == 'asc': query = query.order_by(User.name.asc()) else: query = query.order_by(User.name.desc()) - elif order_by == "email": - if direction == "asc": + elif order_by == 'email': + if direction == 'asc': query = query.order_by(User.email.asc()) else: query = query.order_by(User.email.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(User.created_at.asc()) else: query = query.order_by(User.created_at.desc()) - elif order_by == "last_active_at": - if direction == "asc": + elif order_by == 'last_active_at': + if direction == 'asc': query = query.order_by(User.last_active_at.asc()) else: query = query.order_by(User.last_active_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(User.updated_at.asc()) else: query = query.order_by(User.updated_at.desc()) - elif order_by == "role": - if direction == "asc": + elif order_by == 'role': + if direction == 'asc': query = query.order_by(User.role.asc()) else: query = query.order_by(User.role.desc()) @@ -510,13 +486,11 @@ class UsersTable: users = query.all() return { - "users": [UserModel.model_validate(user) for user in users], - "total": total, + 'users': [UserModel.model_validate(user) for user in users], + 'total': total, } - def get_users_by_group_id( - self, group_id: str, db: Optional[Session] = None - ) -> list[UserModel]: + def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]: with get_db_context(db) as db: users = ( db.query(User) @@ -527,16 +501,9 @@ class UsersTable: ) return [UserModel.model_validate(user) for user in users] - def get_users_by_user_ids( - self, user_ids: list[str], db: Optional[Session] = None - ) -> list[UserStatusModel]: + def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]: with get_db_context(db) as db: - users = ( - db.query(User) - .options(defer(User.profile_image_url)) - .filter(User.id.in_(user_ids)) - .all() - ) + users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all() return [UserModel.model_validate(user) for user in users] def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: @@ -555,9 +522,7 @@ class UsersTable: except Exception: return None - def get_user_webhook_url_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[str]: + def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -565,11 +530,7 @@ class UsersTable: if user.settings is None: return None else: - return ( - user.settings.get("ui", {}) - .get("notifications", {}) - .get("webhook_url", None) - ) + return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None) except Exception: return None @@ -577,14 +538,10 @@ class UsersTable: with get_db_context(db) as db: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) - query = db.query(User).filter( - User.last_active_at > today_midnight_timestamp - ) + query = db.query(User).filter(User.last_active_at > today_midnight_timestamp) return query.count() - def update_user_role_by_id( - self, id: str, role: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -629,9 +586,7 @@ class UsersTable: return None @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) - def update_last_active_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[UserModel]: + def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -665,10 +620,10 @@ class UsersTable: oauth = user.oauth or {} # Update or insert provider entry - oauth[provider] = {"sub": sub} + oauth[provider] = {'sub': sub} # Persist updated JSON - db.query(User).filter_by(id=id).update({"oauth": oauth}) + db.query(User).filter_by(id=id).update({'oauth': oauth}) db.commit() return UserModel.model_validate(user) @@ -698,9 +653,9 @@ class UsersTable: return None scim = user.scim or {} - scim[provider] = {"external_id": external_id} + scim[provider] = {'external_id': external_id} - db.query(User).filter_by(id=id).update({"scim": scim}) + db.query(User).filter_by(id=id).update({'scim': scim}) db.commit() return UserModel.model_validate(user) @@ -708,9 +663,7 @@ class UsersTable: except Exception: return None - def update_user_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[UserModel]: + def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -725,9 +678,7 @@ class UsersTable: print(e) return None - def update_user_settings_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[UserModel]: + def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]: try: with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() @@ -741,7 +692,7 @@ class UsersTable: user_settings.update(updated) - db.query(User).filter_by(id=id).update({"settings": user_settings}) + db.query(User).filter_by(id=id).update({'settings': user_settings}) db.commit() user = db.query(User).filter_by(id=id).first() @@ -768,9 +719,7 @@ class UsersTable: except Exception: return False - def get_user_api_key_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[str]: + def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]: try: with get_db_context(db) as db: api_key = db.query(ApiKey).filter_by(user_id=id).first() @@ -778,9 +727,7 @@ class UsersTable: except Exception: return None - def update_user_api_key_by_id( - self, id: str, api_key: str, db: Optional[Session] = None - ) -> bool: + def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: db.query(ApiKey).filter_by(user_id=id).delete() @@ -788,7 +735,7 @@ class UsersTable: now = int(time.time()) new_api_key = ApiKey( - id=f"key_{id}", + id=f'key_{id}', user_id=id, key=api_key, created_at=now, @@ -811,16 +758,14 @@ class UsersTable: except Exception: return False - def get_valid_user_ids( - self, user_ids: list[str], db: Optional[Session] = None - ) -> list[str]: + def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]: with get_db_context(db) as db: users = db.query(User).filter(User.id.in_(user_ids)).all() return [user.id for user in users] def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]: with get_db_context(db) as db: - user = db.query(User).filter_by(role="admin").first() + user = db.query(User).filter_by(role='admin').first() if user: return UserModel.model_validate(user) else: @@ -830,9 +775,7 @@ class UsersTable: with get_db_context(db) as db: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 - count = ( - db.query(User).filter(User.last_active_at >= three_minutes_ago).count() - ) + count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count() return count @staticmethod diff --git a/backend/open_webui/retrieval/loaders/datalab_marker.py b/backend/open_webui/retrieval/loaders/datalab_marker.py index 8d14be0a40..dd4a763b70 100644 --- a/backend/open_webui/retrieval/loaders/datalab_marker.py +++ b/backend/open_webui/retrieval/loaders/datalab_marker.py @@ -40,78 +40,76 @@ class DatalabMarkerLoader: self.output_format = output_format def _get_mime_type(self, filename: str) -> str: - ext = filename.rsplit(".", 1)[-1].lower() + ext = filename.rsplit('.', 1)[-1].lower() mime_map = { - "pdf": "application/pdf", - "xls": "application/vnd.ms-excel", - "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "ods": "application/vnd.oasis.opendocument.spreadsheet", - "doc": "application/msword", - "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "odt": "application/vnd.oasis.opendocument.text", - "ppt": "application/vnd.ms-powerpoint", - "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", - "odp": "application/vnd.oasis.opendocument.presentation", - "html": "text/html", - "epub": "application/epub+zip", - "png": "image/png", - "jpeg": "image/jpeg", - "jpg": "image/jpeg", - "webp": "image/webp", - "gif": "image/gif", - "tiff": "image/tiff", + 'pdf': 'application/pdf', + 'xls': 'application/vnd.ms-excel', + 'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'ods': 'application/vnd.oasis.opendocument.spreadsheet', + 'doc': 'application/msword', + 'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'odt': 'application/vnd.oasis.opendocument.text', + 'ppt': 'application/vnd.ms-powerpoint', + 'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + 'odp': 'application/vnd.oasis.opendocument.presentation', + 'html': 'text/html', + 'epub': 'application/epub+zip', + 'png': 'image/png', + 'jpeg': 'image/jpeg', + 'jpg': 'image/jpeg', + 'webp': 'image/webp', + 'gif': 'image/gif', + 'tiff': 'image/tiff', } - return mime_map.get(ext, "application/octet-stream") + return mime_map.get(ext, 'application/octet-stream') def check_marker_request_status(self, request_id: str) -> dict: - url = f"{self.api_base_url}/{request_id}" - headers = {"X-Api-Key": self.api_key} + url = f'{self.api_base_url}/{request_id}' + headers = {'X-Api-Key': self.api_key} try: response = requests.get(url, headers=headers) response.raise_for_status() result = response.json() - log.info(f"Marker API status check for request {request_id}: {result}") + log.info(f'Marker API status check for request {request_id}: {result}') return result except requests.HTTPError as e: - log.error(f"Error checking Marker request status: {e}") + log.error(f'Error checking Marker request status: {e}') raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"Failed to check Marker request: {e}", + detail=f'Failed to check Marker request: {e}', ) except ValueError as e: - log.error(f"Invalid JSON checking Marker request: {e}") - raise HTTPException( - status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}" - ) + log.error(f'Invalid JSON checking Marker request: {e}') + raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}') def load(self) -> List[Document]: filename = os.path.basename(self.file_path) mime_type = self._get_mime_type(filename) - headers = {"X-Api-Key": self.api_key} + headers = {'X-Api-Key': self.api_key} form_data = { - "use_llm": str(self.use_llm).lower(), - "skip_cache": str(self.skip_cache).lower(), - "force_ocr": str(self.force_ocr).lower(), - "paginate": str(self.paginate).lower(), - "strip_existing_ocr": str(self.strip_existing_ocr).lower(), - "disable_image_extraction": str(self.disable_image_extraction).lower(), - "format_lines": str(self.format_lines).lower(), - "output_format": self.output_format, + 'use_llm': str(self.use_llm).lower(), + 'skip_cache': str(self.skip_cache).lower(), + 'force_ocr': str(self.force_ocr).lower(), + 'paginate': str(self.paginate).lower(), + 'strip_existing_ocr': str(self.strip_existing_ocr).lower(), + 'disable_image_extraction': str(self.disable_image_extraction).lower(), + 'format_lines': str(self.format_lines).lower(), + 'output_format': self.output_format, } if self.additional_config and self.additional_config.strip(): - form_data["additional_config"] = self.additional_config + form_data['additional_config'] = self.additional_config log.info( f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}" ) try: - with open(self.file_path, "rb") as f: - files = {"file": (filename, f, mime_type)} + with open(self.file_path, 'rb') as f: + files = {'file': (filename, f, mime_type)} response = requests.post( - f"{self.api_base_url}", + f'{self.api_base_url}', data=form_data, files=files, headers=headers, @@ -119,29 +117,25 @@ class DatalabMarkerLoader: response.raise_for_status() result = response.json() except FileNotFoundError: - raise HTTPException( - status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}" - ) + raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}') except requests.HTTPError as e: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Datalab Marker request failed: {e}", + detail=f'Datalab Marker request failed: {e}', ) except ValueError as e: - raise HTTPException( - status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}" - ) + raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}') except Exception as e: raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) - if not result.get("success"): + if not result.get('success'): raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}", + detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}', ) - check_url = result.get("request_check_url") - request_id = result.get("request_id") + check_url = result.get('request_check_url') + request_id = result.get('request_id') # Check if this is a direct response (self-hosted) or polling response (DataLab) if check_url: @@ -154,54 +148,45 @@ class DatalabMarkerLoader: poll_result = poll_response.json() except (requests.HTTPError, ValueError) as e: raw_body = poll_response.text - log.error(f"Polling error: {e}, response body: {raw_body}") - raise HTTPException( - status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}" - ) + log.error(f'Polling error: {e}, response body: {raw_body}') + raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}') - status_val = poll_result.get("status") - success_val = poll_result.get("success") + status_val = poll_result.get('status') + success_val = poll_result.get('success') - if status_val == "complete": + if status_val == 'complete': summary = { k: poll_result.get(k) for k in ( - "status", - "output_format", - "success", - "error", - "page_count", - "total_cost", + 'status', + 'output_format', + 'success', + 'error', + 'page_count', + 'total_cost', ) } - log.info( - f"Marker processing completed successfully: {json.dumps(summary, indent=2)}" - ) + log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}') break - if status_val == "failed" or success_val is False: - log.error( - f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}" - ) - error_msg = ( - poll_result.get("error") - or "Marker returned failure without error message" - ) + if status_val == 'failed' or success_val is False: + log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}') + error_msg = poll_result.get('error') or 'Marker returned failure without error message' raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Marker processing failed: {error_msg}", + detail=f'Marker processing failed: {error_msg}', ) else: raise HTTPException( status.HTTP_504_GATEWAY_TIMEOUT, - detail="Marker processing timed out", + detail='Marker processing timed out', ) - if not poll_result.get("success", False): - error_msg = poll_result.get("error") or "Unknown processing error" + if not poll_result.get('success', False): + error_msg = poll_result.get('error') or 'Unknown processing error' raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Final processing failed: {error_msg}", + detail=f'Final processing failed: {error_msg}', ) # DataLab format - content in format-specific fields @@ -210,69 +195,65 @@ class DatalabMarkerLoader: final_result = poll_result else: # Self-hosted direct response - content in "output" field - if "output" in result: - log.info("Self-hosted Marker returned direct response without polling") - raw_content = result.get("output") + if 'output' in result: + log.info('Self-hosted Marker returned direct response without polling') + raw_content = result.get('output') final_result = result else: - available_fields = ( - list(result.keys()) - if isinstance(result, dict) - else "non-dict response" - ) + available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response' raise HTTPException( status.HTTP_502_BAD_GATEWAY, detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.", ) - if self.output_format.lower() == "json": + if self.output_format.lower() == 'json': full_text = json.dumps(raw_content, indent=2) - elif self.output_format.lower() in {"markdown", "html"}: + elif self.output_format.lower() in {'markdown', 'html'}: full_text = str(raw_content).strip() else: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Unsupported output format: {self.output_format}", + detail=f'Unsupported output format: {self.output_format}', ) if not full_text: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail="Marker returned empty content", + detail='Marker returned empty content', ) - marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output") + marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output') os.makedirs(marker_output_dir, exist_ok=True) - file_ext_map = {"markdown": "md", "json": "json", "html": "html"} - file_ext = file_ext_map.get(self.output_format.lower(), "txt") - output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}" + file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'} + file_ext = file_ext_map.get(self.output_format.lower(), 'txt') + output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}' output_path = os.path.join(marker_output_dir, output_filename) try: - with open(output_path, "w", encoding="utf-8") as f: + with open(output_path, 'w', encoding='utf-8') as f: f.write(full_text) - log.info(f"Saved Marker output to: {output_path}") + log.info(f'Saved Marker output to: {output_path}') except Exception as e: - log.warning(f"Failed to write marker output to disk: {e}") + log.warning(f'Failed to write marker output to disk: {e}') metadata = { - "source": filename, - "output_format": final_result.get("output_format", self.output_format), - "page_count": final_result.get("page_count", 0), - "processed_with_llm": self.use_llm, - "request_id": request_id or "", + 'source': filename, + 'output_format': final_result.get('output_format', self.output_format), + 'page_count': final_result.get('page_count', 0), + 'processed_with_llm': self.use_llm, + 'request_id': request_id or '', } - images = final_result.get("images", {}) + images = final_result.get('images', {}) if images: - metadata["image_count"] = len(images) - metadata["images"] = json.dumps(list(images.keys())) + metadata['image_count'] = len(images) + metadata['images'] = json.dumps(list(images.keys())) for k, v in metadata.items(): if isinstance(v, (dict, list)): metadata[k] = json.dumps(v) elif v is None: - metadata[k] = "" + metadata[k] = '' return [Document(page_content=full_text, metadata=metadata)] diff --git a/backend/open_webui/retrieval/loaders/external_document.py b/backend/open_webui/retrieval/loaders/external_document.py index 90fe70f879..77b1abfcd8 100644 --- a/backend/open_webui/retrieval/loaders/external_document.py +++ b/backend/open_webui/retrieval/loaders/external_document.py @@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader): self.user = user def load(self) -> List[Document]: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: data = f.read() headers = {} if self.mime_type is not None: - headers["Content-Type"] = self.mime_type + headers['Content-Type'] = self.mime_type if self.api_key is not None: - headers["Authorization"] = f"Bearer {self.api_key}" + headers['Authorization'] = f'Bearer {self.api_key}' try: - headers["X-Filename"] = quote(os.path.basename(self.file_path)) + headers['X-Filename'] = quote(os.path.basename(self.file_path)) except Exception: pass @@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader): headers = include_user_info_headers(headers, self.user) url = self.url - if url.endswith("/"): + if url.endswith('/'): url = url[:-1] try: - response = requests.put(f"{url}/process", data=data, headers=headers) + response = requests.put(f'{url}/process', data=data, headers=headers) except Exception as e: - log.error(f"Error connecting to endpoint: {e}") - raise Exception(f"Error connecting to endpoint: {e}") + log.error(f'Error connecting to endpoint: {e}') + raise Exception(f'Error connecting to endpoint: {e}') if response.ok: - response_data = response.json() if response_data: if isinstance(response_data, dict): return [ Document( - page_content=response_data.get("page_content"), - metadata=response_data.get("metadata"), + page_content=response_data.get('page_content'), + metadata=response_data.get('metadata'), ) ] elif isinstance(response_data, list): @@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader): for document in response_data: documents.append( Document( - page_content=document.get("page_content"), - metadata=document.get("metadata"), + page_content=document.get('page_content'), + metadata=document.get('metadata'), ) ) return documents else: - raise Exception("Error loading document: Unable to parse content") + raise Exception('Error loading document: Unable to parse content') else: - raise Exception("Error loading document: No content returned") + raise Exception('Error loading document: No content returned') else: - raise Exception( - f"Error loading document: {response.status_code} {response.text}" - ) + raise Exception(f'Error loading document: {response.status_code} {response.text}') diff --git a/backend/open_webui/retrieval/loaders/external_web.py b/backend/open_webui/retrieval/loaders/external_web.py index 39644caddb..64248427b3 100644 --- a/backend/open_webui/retrieval/loaders/external_web.py +++ b/backend/open_webui/retrieval/loaders/external_web.py @@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader): response = requests.post( self.external_url, headers={ - "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader", - "Authorization": f"Bearer {self.external_api_key}", + 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader', + 'Authorization': f'Bearer {self.external_api_key}', }, json={ - "urls": urls, + 'urls': urls, }, ) response.raise_for_status() results = response.json() for result in results: yield Document( - page_content=result.get("page_content", ""), - metadata=result.get("metadata", {}), + page_content=result.get('page_content', ''), + metadata=result.get('metadata', {}), ) except Exception as e: if self.continue_on_failure: - log.error(f"Error extracting content from batch {urls}: {e}") + log.error(f'Error extracting content from batch {urls}: {e}') else: raise e diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index e817ae100f..57867d78f5 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) known_source_ext = [ - "go", - "py", - "java", - "sh", - "bat", - "ps1", - "cmd", - "js", - "ts", - "css", - "cpp", - "hpp", - "h", - "c", - "cs", - "sql", - "log", - "ini", - "pl", - "pm", - "r", - "dart", - "dockerfile", - "env", - "php", - "hs", - "hsc", - "lua", - "nginxconf", - "conf", - "m", - "mm", - "plsql", - "perl", - "rb", - "rs", - "db2", - "scala", - "bash", - "swift", - "vue", - "svelte", - "ex", - "exs", - "erl", - "tsx", - "jsx", - "hs", - "lhs", - "json", - "yaml", - "yml", - "toml", + 'go', + 'py', + 'java', + 'sh', + 'bat', + 'ps1', + 'cmd', + 'js', + 'ts', + 'css', + 'cpp', + 'hpp', + 'h', + 'c', + 'cs', + 'sql', + 'log', + 'ini', + 'pl', + 'pm', + 'r', + 'dart', + 'dockerfile', + 'env', + 'php', + 'hs', + 'hsc', + 'lua', + 'nginxconf', + 'conf', + 'm', + 'mm', + 'plsql', + 'perl', + 'rb', + 'rs', + 'db2', + 'scala', + 'bash', + 'swift', + 'vue', + 'svelte', + 'ex', + 'exs', + 'erl', + 'tsx', + 'jsx', + 'hs', + 'lhs', + 'json', + 'yaml', + 'yml', + 'toml', ] @@ -99,11 +99,11 @@ class ExcelLoader: xls = pd.ExcelFile(self.file_path) for sheet_name in xls.sheet_names: df = pd.read_excel(xls, sheet_name=sheet_name) - text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}") + text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}') return [ Document( - page_content="\n\n".join(text_parts), - metadata={"source": self.file_path}, + page_content='\n\n'.join(text_parts), + metadata={'source': self.file_path}, ) ] @@ -125,11 +125,11 @@ class PptxLoader: if shape.has_text_frame: slide_texts.append(shape.text_frame.text) if slide_texts: - text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts)) + text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts)) return [ Document( - page_content="\n\n".join(text_parts), - metadata={"source": self.file_path}, + page_content='\n\n'.join(text_parts), + metadata={'source': self.file_path}, ) ] @@ -143,41 +143,41 @@ class TikaLoader: self.extract_images = extract_images def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: data = f.read() if self.mime_type is not None: - headers = {"Content-Type": self.mime_type} + headers = {'Content-Type': self.mime_type} else: headers = {} if self.extract_images == True: - headers["X-Tika-PDFextractInlineImages"] = "true" + headers['X-Tika-PDFextractInlineImages'] = 'true' endpoint = self.url - if not endpoint.endswith("/"): - endpoint += "/" - endpoint += "tika/text" + if not endpoint.endswith('/'): + endpoint += '/' + endpoint += 'tika/text' r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY) if r.ok: raw_metadata = r.json() - text = raw_metadata.get("X-TIKA:content", "").strip() + text = raw_metadata.get('X-TIKA:content', '').strip() - if "Content-Type" in raw_metadata: - headers["Content-Type"] = raw_metadata["Content-Type"] + if 'Content-Type' in raw_metadata: + headers['Content-Type'] = raw_metadata['Content-Type'] - log.debug("Tika extracted text: %s", text) + log.debug('Tika extracted text: %s', text) return [Document(page_content=text, metadata=headers)] else: - raise Exception(f"Error calling Tika: {r.reason}") + raise Exception(f'Error calling Tika: {r.reason}') class DoclingLoader: def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None): - self.url = url.rstrip("/") + self.url = url.rstrip('/') self.api_key = api_key self.file_path = file_path self.mime_type = mime_type @@ -185,199 +185,183 @@ class DoclingLoader: self.params = params or {} def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: headers = {} if self.api_key: - headers["X-Api-Key"] = f"{self.api_key}" + headers['X-Api-Key'] = f'{self.api_key}' r = requests.post( - f"{self.url}/v1/convert/file", + f'{self.url}/v1/convert/file', files={ - "files": ( + 'files': ( self.file_path, f, - self.mime_type or "application/octet-stream", + self.mime_type or 'application/octet-stream', ) }, data={ - "image_export_mode": "placeholder", + 'image_export_mode': 'placeholder', **self.params, }, headers=headers, ) if r.ok: result = r.json() - document_data = result.get("document", {}) - text = document_data.get("md_content", "") + document_data = result.get('document', {}) + text = document_data.get('md_content', '') - metadata = {"Content-Type": self.mime_type} if self.mime_type else {} + metadata = {'Content-Type': self.mime_type} if self.mime_type else {} - log.debug("Docling extracted text: %s", text) + log.debug('Docling extracted text: %s', text) return [Document(page_content=text, metadata=metadata)] else: - error_msg = f"Error calling Docling API: {r.reason}" + error_msg = f'Error calling Docling API: {r.reason}' if r.text: try: error_data = r.json() - if "detail" in error_data: - error_msg += f" - {error_data['detail']}" + if 'detail' in error_data: + error_msg += f' - {error_data["detail"]}' except Exception: - error_msg += f" - {r.text}" - raise Exception(f"Error calling Docling: {error_msg}") + error_msg += f' - {r.text}' + raise Exception(f'Error calling Docling: {error_msg}') class Loader: - def __init__(self, engine: str = "", **kwargs): + def __init__(self, engine: str = '', **kwargs): self.engine = engine - self.user = kwargs.get("user", None) + self.user = kwargs.get('user', None) self.kwargs = kwargs - def load( - self, filename: str, file_content_type: str, file_path: str - ) -> list[Document]: + def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]: loader = self._get_loader(filename, file_content_type, file_path) docs = loader.load() - return [ - Document( - page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata - ) - for doc in docs - ] + return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs] def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: return file_ext in known_source_ext or ( file_content_type - and file_content_type.find("text/") >= 0 + and file_content_type.find('text/') >= 0 # Avoid text/html files being detected as text - and not file_content_type.find("html") >= 0 + and not file_content_type.find('html') >= 0 ) def _get_loader(self, filename: str, file_content_type: str, file_path: str): - file_ext = filename.split(".")[-1].lower() + file_ext = filename.split('.')[-1].lower() if ( - self.engine == "external" - and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL") - and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY") + self.engine == 'external' + and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL') + and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY') ): loader = ExternalDocumentLoader( file_path=file_path, - url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"), - api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), + url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'), + api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'), mime_type=file_content_type, user=self.user, ) - elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): + elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TikaLoader( - url=self.kwargs.get("TIKA_SERVER_URL"), + url=self.kwargs.get('TIKA_SERVER_URL'), file_path=file_path, - extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), + extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'), ) elif ( - self.engine == "datalab_marker" - and self.kwargs.get("DATALAB_MARKER_API_KEY") + self.engine == 'datalab_marker' + and self.kwargs.get('DATALAB_MARKER_API_KEY') and file_ext in [ - "pdf", - "xls", - "xlsx", - "ods", - "doc", - "docx", - "odt", - "ppt", - "pptx", - "odp", - "html", - "epub", - "png", - "jpeg", - "jpg", - "webp", - "gif", - "tiff", + 'pdf', + 'xls', + 'xlsx', + 'ods', + 'doc', + 'docx', + 'odt', + 'ppt', + 'pptx', + 'odp', + 'html', + 'epub', + 'png', + 'jpeg', + 'jpg', + 'webp', + 'gif', + 'tiff', ] ): - api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "") - if not api_base_url or api_base_url.strip() == "": - api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 + api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '') + if not api_base_url or api_base_url.strip() == '': + api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 loader = DatalabMarkerLoader( file_path=file_path, - api_key=self.kwargs["DATALAB_MARKER_API_KEY"], + api_key=self.kwargs['DATALAB_MARKER_API_KEY'], api_base_url=api_base_url, - additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"), - use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False), - skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False), - force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False), - paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False), - strip_existing_ocr=self.kwargs.get( - "DATALAB_MARKER_STRIP_EXISTING_OCR", False - ), - disable_image_extraction=self.kwargs.get( - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False - ), - format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False), - output_format=self.kwargs.get( - "DATALAB_MARKER_OUTPUT_FORMAT", "markdown" - ), + additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'), + use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False), + skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False), + force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False), + paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False), + strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False), + disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False), + format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False), + output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'), ) - elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): + elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: # Build params for DoclingLoader - params = self.kwargs.get("DOCLING_PARAMS", {}) + params = self.kwargs.get('DOCLING_PARAMS', {}) if not isinstance(params, dict): try: params = json.loads(params) except json.JSONDecodeError: - log.error("Invalid DOCLING_PARAMS format, expected JSON object") + log.error('Invalid DOCLING_PARAMS format, expected JSON object') params = {} loader = DoclingLoader( - url=self.kwargs.get("DOCLING_SERVER_URL"), - api_key=self.kwargs.get("DOCLING_API_KEY", None), + url=self.kwargs.get('DOCLING_SERVER_URL'), + api_key=self.kwargs.get('DOCLING_API_KEY', None), file_path=file_path, mime_type=file_content_type, params=params, ) elif ( - self.engine == "document_intelligence" - and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" + self.engine == 'document_intelligence' + and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != '' and ( - file_ext in ["pdf", "docx", "ppt", "pptx"] + file_ext in ['pdf', 'docx', 'ppt', 'pptx'] or file_content_type in [ - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-powerpoint", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.ms-powerpoint', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation', ] ) ): - if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "": + if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '': loader = AzureAIDocumentIntelligenceLoader( file_path=file_path, - api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), - api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), - api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"), + api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'), + api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'), + api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'), ) else: loader = AzureAIDocumentIntelligenceLoader( file_path=file_path, - api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), + api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'), azure_credential=DefaultAzureCredential(), - api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"), + api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'), ) - elif self.engine == "mineru" and file_ext in [ - "pdf" - ]: # MinerU currently only supports PDF - - mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300) + elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF + mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300) if mineru_timeout: try: mineru_timeout = int(mineru_timeout) @@ -386,111 +370,115 @@ class Loader: loader = MinerULoader( file_path=file_path, - api_mode=self.kwargs.get("MINERU_API_MODE", "local"), - api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"), - api_key=self.kwargs.get("MINERU_API_KEY", ""), - params=self.kwargs.get("MINERU_PARAMS", {}), + api_mode=self.kwargs.get('MINERU_API_MODE', 'local'), + api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'), + api_key=self.kwargs.get('MINERU_API_KEY', ''), + params=self.kwargs.get('MINERU_PARAMS', {}), timeout=mineru_timeout, ) elif ( - self.engine == "mistral_ocr" - and self.kwargs.get("MISTRAL_OCR_API_KEY") != "" - and file_ext - in ["pdf"] # Mistral OCR currently only supports PDF and images + self.engine == 'mistral_ocr' + and self.kwargs.get('MISTRAL_OCR_API_KEY') != '' + and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images ): loader = MistralLoader( - base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"), - api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), + base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'), + api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'), file_path=file_path, ) else: - if file_ext == "pdf": + if file_ext == 'pdf': loader = PyPDFLoader( file_path, - extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), - mode=self.kwargs.get("PDF_LOADER_MODE", "page"), + extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'), + mode=self.kwargs.get('PDF_LOADER_MODE', 'page'), ) - elif file_ext == "csv": + elif file_ext == 'csv': loader = CSVLoader(file_path, autodetect_encoding=True) - elif file_ext == "rst": + elif file_ext == 'rst': try: from langchain_community.document_loaders import UnstructuredRSTLoader - loader = UnstructuredRSTLoader(file_path, mode="elements") + + loader = UnstructuredRSTLoader(file_path, mode='elements') except ImportError: log.warning( "The 'unstructured' package is not installed. " - "Falling back to plain text loading for .rst file. " - "Install it with: pip install unstructured" + 'Falling back to plain text loading for .rst file. ' + 'Install it with: pip install unstructured' ) loader = TextLoader(file_path, autodetect_encoding=True) - elif file_ext == "xml": + elif file_ext == 'xml': try: from langchain_community.document_loaders import UnstructuredXMLLoader + loader = UnstructuredXMLLoader(file_path) except ImportError: log.warning( "The 'unstructured' package is not installed. " - "Falling back to plain text loading for .xml file. " - "Install it with: pip install unstructured" + 'Falling back to plain text loading for .xml file. ' + 'Install it with: pip install unstructured' ) loader = TextLoader(file_path, autodetect_encoding=True) - elif file_ext in ["htm", "html"]: - loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") - elif file_ext == "md": + elif file_ext in ['htm', 'html']: + loader = BSHTMLLoader(file_path, open_encoding='unicode_escape') + elif file_ext == 'md': loader = TextLoader(file_path, autodetect_encoding=True) - elif file_content_type == "application/epub+zip": + elif file_content_type == 'application/epub+zip': try: from langchain_community.document_loaders import UnstructuredEPubLoader + loader = UnstructuredEPubLoader(file_path) except ImportError: raise ValueError( "Processing .epub files requires the 'unstructured' package. " - "Install it with: pip install unstructured" + 'Install it with: pip install unstructured' ) elif ( - file_content_type - == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - or file_ext == "docx" + file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' + or file_ext == 'docx' ): loader = Docx2txtLoader(file_path) elif file_content_type in [ - "application/vnd.ms-excel", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ] or file_ext in ["xls", "xlsx"]: + 'application/vnd.ms-excel', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + ] or file_ext in ['xls', 'xlsx']: try: from langchain_community.document_loaders import UnstructuredExcelLoader + loader = UnstructuredExcelLoader(file_path) except ImportError: log.warning( "The 'unstructured' package is not installed. " - "Falling back to pandas for Excel file loading. " - "Install unstructured for better results: pip install unstructured" + 'Falling back to pandas for Excel file loading. ' + 'Install unstructured for better results: pip install unstructured' ) loader = ExcelLoader(file_path) elif file_content_type in [ - "application/vnd.ms-powerpoint", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ] or file_ext in ["ppt", "pptx"]: + 'application/vnd.ms-powerpoint', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + ] or file_ext in ['ppt', 'pptx']: try: from langchain_community.document_loaders import UnstructuredPowerPointLoader + loader = UnstructuredPowerPointLoader(file_path) except ImportError: log.warning( "The 'unstructured' package is not installed. " - "Falling back to python-pptx for PowerPoint file loading. " - "Install unstructured for better results: pip install unstructured" + 'Falling back to python-pptx for PowerPoint file loading. ' + 'Install unstructured for better results: pip install unstructured' ) loader = PptxLoader(file_path) - elif file_ext == "msg": + elif file_ext == 'msg': loader = OutlookMessageLoader(file_path) - elif file_ext == "odt": + elif file_ext == 'odt': try: from langchain_community.document_loaders import UnstructuredODTLoader + loader = UnstructuredODTLoader(file_path) except ImportError: raise ValueError( "Processing .odt files requires the 'unstructured' package. " - "Install it with: pip install unstructured" + 'Install it with: pip install unstructured' ) elif self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) @@ -498,4 +486,3 @@ class Loader: loader = TextLoader(file_path, autodetect_encoding=True) return loader - diff --git a/backend/open_webui/retrieval/loaders/mineru.py b/backend/open_webui/retrieval/loaders/mineru.py index 108abd8b9a..1f0848a613 100644 --- a/backend/open_webui/retrieval/loaders/mineru.py +++ b/backend/open_webui/retrieval/loaders/mineru.py @@ -22,37 +22,35 @@ class MinerULoader: def __init__( self, file_path: str, - api_mode: str = "local", - api_url: str = "http://localhost:8000", - api_key: str = "", + api_mode: str = 'local', + api_url: str = 'http://localhost:8000', + api_key: str = '', params: dict = None, timeout: Optional[int] = 300, ): self.file_path = file_path self.api_mode = api_mode.lower() - self.api_url = api_url.rstrip("/") + self.api_url = api_url.rstrip('/') self.api_key = api_key self.timeout = timeout # Parse params dict with defaults self.params = params or {} - self.enable_ocr = params.get("enable_ocr", False) - self.enable_formula = params.get("enable_formula", True) - self.enable_table = params.get("enable_table", True) - self.language = params.get("language", "en") - self.model_version = params.get("model_version", "pipeline") + self.enable_ocr = params.get('enable_ocr', False) + self.enable_formula = params.get('enable_formula', True) + self.enable_table = params.get('enable_table', True) + self.language = params.get('language', 'en') + self.model_version = params.get('model_version', 'pipeline') - self.page_ranges = self.params.pop("page_ranges", "") + self.page_ranges = self.params.pop('page_ranges', '') # Validate API mode - if self.api_mode not in ["local", "cloud"]: - raise ValueError( - f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'" - ) + if self.api_mode not in ['local', 'cloud']: + raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'") # Validate Cloud API requirements - if self.api_mode == "cloud" and not self.api_key: - raise ValueError("API key is required for Cloud API mode") + if self.api_mode == 'cloud' and not self.api_key: + raise ValueError('API key is required for Cloud API mode') def load(self) -> List[Document]: """ @@ -60,12 +58,12 @@ class MinerULoader: Routes to Cloud or Local API based on api_mode. """ try: - if self.api_mode == "cloud": + if self.api_mode == 'cloud': return self._load_cloud_api() else: return self._load_local_api() except Exception as e: - log.error(f"Error loading document with MinerU: {e}") + log.error(f'Error loading document with MinerU: {e}') raise def _load_local_api(self) -> List[Document]: @@ -73,14 +71,14 @@ class MinerULoader: Load document using Local API (synchronous). Posts file to /file_parse endpoint and gets immediate response. """ - log.info(f"Using MinerU Local API at {self.api_url}") + log.info(f'Using MinerU Local API at {self.api_url}') filename = os.path.basename(self.file_path) # Build form data for Local API form_data = { **self.params, - "return_md": "true", + 'return_md': 'true', } # Page ranges (Local API uses start_page_id and end_page_id) @@ -89,18 +87,18 @@ class MinerULoader: # Full page range parsing would require parsing the string log.warning( f"Page ranges '{self.page_ranges}' specified but Local API uses different format. " - "Consider using start_page_id/end_page_id parameters if needed." + 'Consider using start_page_id/end_page_id parameters if needed.' ) try: - with open(self.file_path, "rb") as f: - files = {"files": (filename, f, "application/octet-stream")} + with open(self.file_path, 'rb') as f: + files = {'files': (filename, f, 'application/octet-stream')} - log.info(f"Sending file to MinerU Local API: {filename}") - log.debug(f"Local API parameters: {form_data}") + log.info(f'Sending file to MinerU Local API: {filename}') + log.debug(f'Local API parameters: {form_data}') response = requests.post( - f"{self.api_url}/file_parse", + f'{self.api_url}/file_parse', data=form_data, files=files, timeout=self.timeout, @@ -108,27 +106,25 @@ class MinerULoader: response.raise_for_status() except FileNotFoundError: - raise HTTPException( - status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}" - ) + raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}') except requests.Timeout: raise HTTPException( status.HTTP_504_GATEWAY_TIMEOUT, - detail="MinerU Local API request timed out", + detail='MinerU Local API request timed out', ) except requests.HTTPError as e: - error_detail = f"MinerU Local API request failed: {e}" + error_detail = f'MinerU Local API request failed: {e}' if e.response is not None: try: error_data = e.response.json() - error_detail += f" - {error_data}" + error_detail += f' - {error_data}' except Exception: - error_detail += f" - {e.response.text}" + error_detail += f' - {e.response.text}' raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error calling MinerU Local API: {str(e)}", + detail=f'Error calling MinerU Local API: {str(e)}', ) # Parse response @@ -137,41 +133,41 @@ class MinerULoader: except ValueError as e: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"Invalid JSON response from MinerU Local API: {e}", + detail=f'Invalid JSON response from MinerU Local API: {e}', ) # Extract markdown content from response - if "results" not in result: + if 'results' not in result: raise HTTPException( status.HTTP_502_BAD_GATEWAY, detail="MinerU Local API response missing 'results' field", ) - results = result["results"] + results = result['results'] if not results: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail="MinerU returned empty results", + detail='MinerU returned empty results', ) # Get the first (and typically only) result file_result = list(results.values())[0] - markdown_content = file_result.get("md_content", "") + markdown_content = file_result.get('md_content', '') if not markdown_content: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail="MinerU returned empty markdown content", + detail='MinerU returned empty markdown content', ) - log.info(f"Successfully parsed document with MinerU Local API: {filename}") + log.info(f'Successfully parsed document with MinerU Local API: {filename}') # Create metadata metadata = { - "source": filename, - "api_mode": "local", - "backend": result.get("backend", "unknown"), - "version": result.get("version", "unknown"), + 'source': filename, + 'api_mode': 'local', + 'backend': result.get('backend', 'unknown'), + 'version': result.get('version', 'unknown'), } return [Document(page_content=markdown_content, metadata=metadata)] @@ -181,7 +177,7 @@ class MinerULoader: Load document using Cloud API (asynchronous). Uses batch upload endpoint to avoid need for public file URLs. """ - log.info(f"Using MinerU Cloud API at {self.api_url}") + log.info(f'Using MinerU Cloud API at {self.api_url}') filename = os.path.basename(self.file_path) @@ -195,17 +191,15 @@ class MinerULoader: result = self._poll_batch_status(batch_id, filename) # Step 4: Download and extract markdown from ZIP - markdown_content = self._download_and_extract_zip( - result["full_zip_url"], filename - ) + markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename) - log.info(f"Successfully parsed document with MinerU Cloud API: {filename}") + log.info(f'Successfully parsed document with MinerU Cloud API: {filename}') # Create metadata metadata = { - "source": filename, - "api_mode": "cloud", - "batch_id": batch_id, + 'source': filename, + 'api_mode': 'cloud', + 'batch_id': batch_id, } return [Document(page_content=markdown_content, metadata=metadata)] @@ -216,49 +210,49 @@ class MinerULoader: Returns (batch_id, upload_url). """ headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json', } # Build request body request_body = { **self.params, - "files": [ + 'files': [ { - "name": filename, - "is_ocr": self.enable_ocr, + 'name': filename, + 'is_ocr': self.enable_ocr, } ], } # Add page ranges if specified if self.page_ranges: - request_body["files"][0]["page_ranges"] = self.page_ranges + request_body['files'][0]['page_ranges'] = self.page_ranges - log.info(f"Requesting upload URL for: {filename}") - log.debug(f"Cloud API request body: {request_body}") + log.info(f'Requesting upload URL for: {filename}') + log.debug(f'Cloud API request body: {request_body}') try: response = requests.post( - f"{self.api_url}/file-urls/batch", + f'{self.api_url}/file-urls/batch', headers=headers, json=request_body, timeout=30, ) response.raise_for_status() except requests.HTTPError as e: - error_detail = f"Failed to request upload URL: {e}" + error_detail = f'Failed to request upload URL: {e}' if e.response is not None: try: error_data = e.response.json() - error_detail += f" - {error_data.get('msg', error_data)}" + error_detail += f' - {error_data.get("msg", error_data)}' except Exception: - error_detail += f" - {e.response.text}" + error_detail += f' - {e.response.text}' raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error requesting upload URL: {str(e)}", + detail=f'Error requesting upload URL: {str(e)}', ) try: @@ -266,28 +260,28 @@ class MinerULoader: except ValueError as e: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"Invalid JSON response: {e}", + detail=f'Invalid JSON response: {e}', ) # Check for API error response - if result.get("code") != 0: + if result.get('code') != 0: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}", + detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}', ) - data = result.get("data", {}) - batch_id = data.get("batch_id") - file_urls = data.get("file_urls", []) + data = result.get('data', {}) + batch_id = data.get('batch_id') + file_urls = data.get('file_urls', []) if not batch_id or not file_urls: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail="MinerU Cloud API response missing batch_id or file_urls", + detail='MinerU Cloud API response missing batch_id or file_urls', ) upload_url = file_urls[0] - log.info(f"Received upload URL for batch: {batch_id}") + log.info(f'Received upload URL for batch: {batch_id}') return batch_id, upload_url @@ -295,10 +289,10 @@ class MinerULoader: """ Upload file to presigned URL (no authentication needed). """ - log.info(f"Uploading file to presigned URL") + log.info(f'Uploading file to presigned URL') try: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: response = requests.put( upload_url, data=f, @@ -306,26 +300,24 @@ class MinerULoader: ) response.raise_for_status() except FileNotFoundError: - raise HTTPException( - status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}" - ) + raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}') except requests.Timeout: raise HTTPException( status.HTTP_504_GATEWAY_TIMEOUT, - detail="File upload to presigned URL timed out", + detail='File upload to presigned URL timed out', ) except requests.HTTPError as e: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Failed to upload file to presigned URL: {e}", + detail=f'Failed to upload file to presigned URL: {e}', ) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error uploading file: {str(e)}", + detail=f'Error uploading file: {str(e)}', ) - log.info("File uploaded successfully") + log.info('File uploaded successfully') def _poll_batch_status(self, batch_id: str, filename: str) -> dict: """ @@ -333,35 +325,35 @@ class MinerULoader: Returns the result dict for the file. """ headers = { - "Authorization": f"Bearer {self.api_key}", + 'Authorization': f'Bearer {self.api_key}', } max_iterations = 300 # 10 minutes max (2 seconds per iteration) poll_interval = 2 # seconds - log.info(f"Polling batch status: {batch_id}") + log.info(f'Polling batch status: {batch_id}') for iteration in range(max_iterations): try: response = requests.get( - f"{self.api_url}/extract-results/batch/{batch_id}", + f'{self.api_url}/extract-results/batch/{batch_id}', headers=headers, timeout=30, ) response.raise_for_status() except requests.HTTPError as e: - error_detail = f"Failed to poll batch status: {e}" + error_detail = f'Failed to poll batch status: {e}' if e.response is not None: try: error_data = e.response.json() - error_detail += f" - {error_data.get('msg', error_data)}" + error_detail += f' - {error_data.get("msg", error_data)}' except Exception: - error_detail += f" - {e.response.text}" + error_detail += f' - {e.response.text}' raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error polling batch status: {str(e)}", + detail=f'Error polling batch status: {str(e)}', ) try: @@ -369,58 +361,56 @@ class MinerULoader: except ValueError as e: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"Invalid JSON response while polling: {e}", + detail=f'Invalid JSON response while polling: {e}', ) # Check for API error response - if result.get("code") != 0: + if result.get('code') != 0: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}", + detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}', ) - data = result.get("data", {}) - extract_result = data.get("extract_result", []) + data = result.get('data', {}) + extract_result = data.get('extract_result', []) # Find our file in the batch results file_result = None for item in extract_result: - if item.get("file_name") == filename: + if item.get('file_name') == filename: file_result = item break if not file_result: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"File {filename} not found in batch results", + detail=f'File {filename} not found in batch results', ) - state = file_result.get("state") + state = file_result.get('state') - if state == "done": - log.info(f"Processing complete for {filename}") + if state == 'done': + log.info(f'Processing complete for {filename}') return file_result - elif state == "failed": - error_msg = file_result.get("err_msg", "Unknown error") + elif state == 'failed': + error_msg = file_result.get('err_msg', 'Unknown error') raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"MinerU processing failed: {error_msg}", + detail=f'MinerU processing failed: {error_msg}', ) - elif state in ["waiting-file", "pending", "running", "converting"]: + elif state in ['waiting-file', 'pending', 'running', 'converting']: # Still processing if iteration % 10 == 0: # Log every 20 seconds - log.info( - f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})" - ) + log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})') time.sleep(poll_interval) else: - log.warning(f"Unknown state: {state}") + log.warning(f'Unknown state: {state}') time.sleep(poll_interval) # Timeout raise HTTPException( status.HTTP_504_GATEWAY_TIMEOUT, - detail="MinerU processing timed out after 10 minutes", + detail='MinerU processing timed out after 10 minutes', ) def _download_and_extract_zip(self, zip_url: str, filename: str) -> str: @@ -428,7 +418,7 @@ class MinerULoader: Download ZIP file from CDN and extract markdown content. Returns the markdown content as a string. """ - log.info(f"Downloading results from: {zip_url}") + log.info(f'Downloading results from: {zip_url}') try: response = requests.get(zip_url, timeout=60) @@ -436,23 +426,23 @@ class MinerULoader: except requests.HTTPError as e: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail=f"Failed to download results ZIP: {e}", + detail=f'Failed to download results ZIP: {e}', ) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error downloading results: {str(e)}", + detail=f'Error downloading results: {str(e)}', ) # Save ZIP to temporary file and extract try: - with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: + with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip: tmp_zip.write(response.content) tmp_zip_path = tmp_zip.name with tempfile.TemporaryDirectory() as tmp_dir: # Extract ZIP - with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref: + with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref: zip_ref.extractall(tmp_dir) # Find markdown file - search recursively for any .md file @@ -466,33 +456,27 @@ class MinerULoader: full_path = os.path.join(root, file) all_files.append(full_path) # Look for any .md file - if file.endswith(".md"): + if file.endswith('.md'): found_md_path = full_path - log.info(f"Found markdown file at: {full_path}") + log.info(f'Found markdown file at: {full_path}') try: - with open(full_path, "r", encoding="utf-8") as f: + with open(full_path, 'r', encoding='utf-8') as f: markdown_content = f.read() - if ( - markdown_content - ): # Use the first non-empty markdown file + if markdown_content: # Use the first non-empty markdown file break except Exception as e: - log.warning(f"Failed to read {full_path}: {e}") + log.warning(f'Failed to read {full_path}: {e}') if markdown_content: break if markdown_content is None: - log.error(f"Available files in ZIP: {all_files}") + log.error(f'Available files in ZIP: {all_files}') # Try to provide more helpful error message - md_files = [f for f in all_files if f.endswith(".md")] + md_files = [f for f in all_files if f.endswith('.md')] if md_files: - error_msg = ( - f"Found .md files but couldn't read them: {md_files}" - ) + error_msg = f"Found .md files but couldn't read them: {md_files}" else: - error_msg = ( - f"No .md files found in ZIP. Available files: {all_files}" - ) + error_msg = f'No .md files found in ZIP. Available files: {all_files}' raise HTTPException( status.HTTP_502_BAD_GATEWAY, detail=error_msg, @@ -504,21 +488,19 @@ class MinerULoader: except zipfile.BadZipFile as e: raise HTTPException( status.HTTP_502_BAD_GATEWAY, - detail=f"Invalid ZIP file received: {e}", + detail=f'Invalid ZIP file received: {e}', ) except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error extracting ZIP: {str(e)}", + detail=f'Error extracting ZIP: {str(e)}', ) if not markdown_content: raise HTTPException( status.HTTP_400_BAD_REQUEST, - detail="Extracted markdown content is empty", + detail='Extracted markdown content is empty', ) - log.info( - f"Successfully extracted markdown content ({len(markdown_content)} characters)" - ) + log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)') return markdown_content diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py index 68570757c8..e46863a96a 100644 --- a/backend/open_webui/retrieval/loaders/mistral.py +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -49,13 +49,11 @@ class MistralLoader: enable_debug_logging: Enable detailed debug logs. """ if not api_key: - raise ValueError("API key cannot be empty.") + raise ValueError('API key cannot be empty.') if not os.path.exists(file_path): - raise FileNotFoundError(f"File not found at {file_path}") + raise FileNotFoundError(f'File not found at {file_path}') - self.base_url = ( - base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1" - ) + self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1' self.api_key = api_key self.file_path = file_path self.timeout = timeout @@ -65,18 +63,10 @@ class MistralLoader: # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations # This prevents long-running OCR operations from affecting quick operations # and improves user experience by failing fast on operations that should be quick - self.upload_timeout = min( - timeout, 120 - ) # Cap upload at 2 minutes - prevents hanging on large files - self.url_timeout = ( - 30 # URL requests should be fast - fail quickly if API is slow - ) - self.ocr_timeout = ( - timeout # OCR can take the full timeout - this is the heavy operation - ) - self.cleanup_timeout = ( - 30 # Cleanup should be quick - don't hang on file deletion - ) + self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files + self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow + self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation + self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls # This avoids multiple os.path.basename() and os.path.getsize() calls during processing @@ -85,8 +75,8 @@ class MistralLoader: # ENHANCEMENT: Added User-Agent for better API tracking and debugging self.headers = { - "Authorization": f"Bearer {self.api_key}", - "User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage + 'Authorization': f'Bearer {self.api_key}', + 'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage } def _debug_log(self, message: str, *args) -> None: @@ -108,43 +98,39 @@ class MistralLoader: return {} # Return empty dict if no content return response.json() except requests.exceptions.HTTPError as http_err: - log.error(f"HTTP error occurred: {http_err} - Response: {response.text}") + log.error(f'HTTP error occurred: {http_err} - Response: {response.text}') raise except requests.exceptions.RequestException as req_err: - log.error(f"Request exception occurred: {req_err}") + log.error(f'Request exception occurred: {req_err}') raise except ValueError as json_err: # Includes JSONDecodeError - log.error(f"JSON decode error: {json_err} - Response: {response.text}") + log.error(f'JSON decode error: {json_err} - Response: {response.text}') raise # Re-raise after logging - async def _handle_response_async( - self, response: aiohttp.ClientResponse - ) -> Dict[str, Any]: + async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]: """Async version of response handling with better error info.""" try: response.raise_for_status() # Check content type - content_type = response.headers.get("content-type", "") - if "application/json" not in content_type: + content_type = response.headers.get('content-type', '') + if 'application/json' not in content_type: if response.status == 204: return {} text = await response.text() - raise ValueError( - f"Unexpected content type: {content_type}, body: {text[:200]}..." - ) + raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...') return await response.json() except aiohttp.ClientResponseError as e: - error_text = await response.text() if response else "No response" - log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}") + error_text = await response.text() if response else 'No response' + log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}') raise except aiohttp.ClientError as e: - log.error(f"Client error: {e}") + log.error(f'Client error: {e}') raise except Exception as e: - log.error(f"Unexpected error processing response: {e}") + log.error(f'Unexpected error processing response: {e}') raise def _is_retryable_error(self, error: Exception) -> bool: @@ -172,13 +158,11 @@ class MistralLoader: return True # Timeouts might resolve on retry if isinstance(error, requests.exceptions.HTTPError): # Only retry on server errors (5xx) or rate limits (429) - if hasattr(error, "response") and error.response is not None: + if hasattr(error, 'response') and error.response is not None: status_code = error.response.status_code return status_code >= 500 or status_code == 429 return False - if isinstance( - error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError) - ): + if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)): return True # Async network/timeout errors are retryable if isinstance(error, aiohttp.ClientResponseError): return error.status >= 500 or error.status == 429 @@ -204,8 +188,7 @@ class MistralLoader: # Prevents overwhelming the server while ensuring reasonable retry delays wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds log.warning( - f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " - f"Retrying in {wait_time}s..." + f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...' ) time.sleep(wait_time) @@ -226,8 +209,7 @@ class MistralLoader: # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds log.warning( - f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " - f"Retrying in {wait_time}s..." + f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...' ) await asyncio.sleep(wait_time) # Non-blocking wait @@ -240,15 +222,15 @@ class MistralLoader: Although streaming is not enabled for this endpoint, the file is opened in a context manager to minimize memory usage duration. """ - log.info("Uploading file to Mistral API") - url = f"{self.base_url}/files" + log.info('Uploading file to Mistral API') + url = f'{self.base_url}/files' def upload_request(): # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime # This ensures the file is closed immediately after reading, reducing memory usage - with open(self.file_path, "rb") as f: - files = {"file": (self.file_name, f, "application/pdf")} - data = {"purpose": "ocr"} + with open(self.file_path, 'rb') as f: + files = {'file': (self.file_name, f, 'application/pdf')} + data = {'purpose': 'ocr'} # NOTE: stream=False is required for this endpoint # The Mistral API doesn't support chunked uploads for this endpoint @@ -265,42 +247,38 @@ class MistralLoader: try: response_data = self._retry_request_sync(upload_request) - file_id = response_data.get("id") + file_id = response_data.get('id') if not file_id: - raise ValueError("File ID not found in upload response.") - log.info(f"File uploaded successfully. File ID: {file_id}") + raise ValueError('File ID not found in upload response.') + log.info(f'File uploaded successfully. File ID: {file_id}') return file_id except Exception as e: - log.error(f"Failed to upload file: {e}") + log.error(f'Failed to upload file: {e}') raise async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: """Async file upload with streaming for better memory efficiency.""" - url = f"{self.base_url}/files" + url = f'{self.base_url}/files' async def upload_request(): # Create multipart writer for streaming upload - writer = aiohttp.MultipartWriter("form-data") + writer = aiohttp.MultipartWriter('form-data') # Add purpose field - purpose_part = writer.append("ocr") - purpose_part.set_content_disposition("form-data", name="purpose") + purpose_part = writer.append('ocr') + purpose_part.set_content_disposition('form-data', name='purpose') # Add file part with streaming file_part = writer.append_payload( aiohttp.streams.FilePayload( self.file_path, filename=self.file_name, - content_type="application/pdf", + content_type='application/pdf', ) ) - file_part.set_content_disposition( - "form-data", name="file", filename=self.file_name - ) + file_part.set_content_disposition('form-data', name='file', filename=self.file_name) - self._debug_log( - f"Uploading file: {self.file_name} ({self.file_size:,} bytes)" - ) + self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)') async with session.post( url, @@ -312,48 +290,44 @@ class MistralLoader: response_data = await self._retry_request_async(upload_request) - file_id = response_data.get("id") + file_id = response_data.get('id') if not file_id: - raise ValueError("File ID not found in upload response.") + raise ValueError('File ID not found in upload response.') - log.info(f"File uploaded successfully. File ID: {file_id}") + log.info(f'File uploaded successfully. File ID: {file_id}') return file_id def _get_signed_url(self, file_id: str) -> str: """Retrieves a temporary signed URL for the uploaded file (sync version).""" - log.info(f"Getting signed URL for file ID: {file_id}") - url = f"{self.base_url}/files/{file_id}/url" - params = {"expiry": 1} - signed_url_headers = {**self.headers, "Accept": "application/json"} + log.info(f'Getting signed URL for file ID: {file_id}') + url = f'{self.base_url}/files/{file_id}/url' + params = {'expiry': 1} + signed_url_headers = {**self.headers, 'Accept': 'application/json'} def url_request(): - response = requests.get( - url, headers=signed_url_headers, params=params, timeout=self.url_timeout - ) + response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout) return self._handle_response(response) try: response_data = self._retry_request_sync(url_request) - signed_url = response_data.get("url") + signed_url = response_data.get('url') if not signed_url: - raise ValueError("Signed URL not found in response.") - log.info("Signed URL received.") + raise ValueError('Signed URL not found in response.') + log.info('Signed URL received.') return signed_url except Exception as e: - log.error(f"Failed to get signed URL: {e}") + log.error(f'Failed to get signed URL: {e}') raise - async def _get_signed_url_async( - self, session: aiohttp.ClientSession, file_id: str - ) -> str: + async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str: """Async signed URL retrieval.""" - url = f"{self.base_url}/files/{file_id}/url" - params = {"expiry": 1} + url = f'{self.base_url}/files/{file_id}/url' + params = {'expiry': 1} - headers = {**self.headers, "Accept": "application/json"} + headers = {**self.headers, 'Accept': 'application/json'} async def url_request(): - self._debug_log(f"Getting signed URL for file ID: {file_id}") + self._debug_log(f'Getting signed URL for file ID: {file_id}') async with session.get( url, headers=headers, @@ -364,69 +338,65 @@ class MistralLoader: response_data = await self._retry_request_async(url_request) - signed_url = response_data.get("url") + signed_url = response_data.get('url') if not signed_url: - raise ValueError("Signed URL not found in response.") + raise ValueError('Signed URL not found in response.') - self._debug_log("Signed URL received successfully") + self._debug_log('Signed URL received successfully') return signed_url def _process_ocr(self, signed_url: str) -> Dict[str, Any]: """Sends the signed URL to the OCR endpoint for processing (sync version).""" - log.info("Processing OCR via Mistral API") - url = f"{self.base_url}/ocr" + log.info('Processing OCR via Mistral API') + url = f'{self.base_url}/ocr' ocr_headers = { **self.headers, - "Content-Type": "application/json", - "Accept": "application/json", + 'Content-Type': 'application/json', + 'Accept': 'application/json', } payload = { - "model": "mistral-ocr-latest", - "document": { - "type": "document_url", - "document_url": signed_url, + 'model': 'mistral-ocr-latest', + 'document': { + 'type': 'document_url', + 'document_url': signed_url, }, - "include_image_base64": False, + 'include_image_base64': False, } def ocr_request(): - response = requests.post( - url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout - ) + response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout) return self._handle_response(response) try: ocr_response = self._retry_request_sync(ocr_request) - log.info("OCR processing done.") - self._debug_log("OCR response: %s", ocr_response) + log.info('OCR processing done.') + self._debug_log('OCR response: %s', ocr_response) return ocr_response except Exception as e: - log.error(f"Failed during OCR processing: {e}") + log.error(f'Failed during OCR processing: {e}') raise - async def _process_ocr_async( - self, session: aiohttp.ClientSession, signed_url: str - ) -> Dict[str, Any]: + async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]: """Async OCR processing with timing metrics.""" - url = f"{self.base_url}/ocr" + url = f'{self.base_url}/ocr' headers = { **self.headers, - "Content-Type": "application/json", - "Accept": "application/json", + 'Content-Type': 'application/json', + 'Accept': 'application/json', } payload = { - "model": "mistral-ocr-latest", - "document": { - "type": "document_url", - "document_url": signed_url, + 'model': 'mistral-ocr-latest', + 'document': { + 'type': 'document_url', + 'document_url': signed_url, }, - "include_image_base64": False, + 'include_image_base64': False, } async def ocr_request(): - log.info("Starting OCR processing via Mistral API") + log.info('Starting OCR processing via Mistral API') start_time = time.time() async with session.post( @@ -438,7 +408,7 @@ class MistralLoader: ocr_response = await self._handle_response_async(response) processing_time = time.time() - start_time - log.info(f"OCR processing completed in {processing_time:.2f}s") + log.info(f'OCR processing completed in {processing_time:.2f}s') return ocr_response @@ -446,42 +416,36 @@ class MistralLoader: def _delete_file(self, file_id: str) -> None: """Deletes the file from Mistral storage (sync version).""" - log.info(f"Deleting uploaded file ID: {file_id}") - url = f"{self.base_url}/files/{file_id}" + log.info(f'Deleting uploaded file ID: {file_id}') + url = f'{self.base_url}/files/{file_id}' try: - response = requests.delete( - url, headers=self.headers, timeout=self.cleanup_timeout - ) + response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout) delete_response = self._handle_response(response) - log.info(f"File deleted successfully: {delete_response}") + log.info(f'File deleted successfully: {delete_response}') except Exception as e: # Log error but don't necessarily halt execution if deletion fails - log.error(f"Failed to delete file ID {file_id}: {e}") + log.error(f'Failed to delete file ID {file_id}: {e}') - async def _delete_file_async( - self, session: aiohttp.ClientSession, file_id: str - ) -> None: + async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None: """Async file deletion with error tolerance.""" try: async def delete_request(): - self._debug_log(f"Deleting file ID: {file_id}") + self._debug_log(f'Deleting file ID: {file_id}') async with session.delete( - url=f"{self.base_url}/files/{file_id}", + url=f'{self.base_url}/files/{file_id}', headers=self.headers, - timeout=aiohttp.ClientTimeout( - total=self.cleanup_timeout - ), # Shorter timeout for cleanup + timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup ) as response: return await self._handle_response_async(response) await self._retry_request_async(delete_request) - self._debug_log(f"File {file_id} deleted successfully") + self._debug_log(f'File {file_id} deleted successfully') except Exception as e: # Don't fail the entire process if cleanup fails - log.warning(f"Failed to delete file ID {file_id}: {e}") + log.warning(f'Failed to delete file ID {file_id}: {e}') @asynccontextmanager async def _get_session(self): @@ -506,7 +470,7 @@ class MistralLoader: async with aiohttp.ClientSession( connector=connector, timeout=timeout, - headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, + headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'}, raise_for_status=False, # We handle status codes manually trust_env=True, ) as session: @@ -514,13 +478,13 @@ class MistralLoader: def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: """Process OCR results into Document objects with enhanced metadata and memory efficiency.""" - pages_data = ocr_response.get("pages") + pages_data = ocr_response.get('pages') if not pages_data: - log.warning("No pages found in OCR response.") + log.warning('No pages found in OCR response.') return [ Document( - page_content="No text content found", - metadata={"error": "no_pages", "file_name": self.file_name}, + page_content='No text content found', + metadata={'error': 'no_pages', 'file_name': self.file_name}, ) ] @@ -530,8 +494,8 @@ class MistralLoader: # Process pages in a memory-efficient way for page_data in pages_data: - page_content = page_data.get("markdown") - page_index = page_data.get("index") # API uses 0-based index + page_content = page_data.get('markdown') + page_index = page_data.get('index') # API uses 0-based index if page_content is None or page_index is None: skipped_pages += 1 @@ -548,7 +512,7 @@ class MistralLoader: if not cleaned_content: skipped_pages += 1 - self._debug_log(f"Skipping empty page {page_index}") + self._debug_log(f'Skipping empty page {page_index}') continue # Create document with optimized metadata @@ -556,34 +520,30 @@ class MistralLoader: Document( page_content=cleaned_content, metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index + 1, # 1-based label for convenience - "total_pages": total_pages, - "file_name": self.file_name, - "file_size": self.file_size, - "processing_engine": "mistral-ocr", - "content_length": len(cleaned_content), + 'page': page_index, # 0-based index from API + 'page_label': page_index + 1, # 1-based label for convenience + 'total_pages': total_pages, + 'file_name': self.file_name, + 'file_size': self.file_size, + 'processing_engine': 'mistral-ocr', + 'content_length': len(cleaned_content), }, ) ) if skipped_pages > 0: - log.info( - f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages" - ) + log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages') if not documents: # Case where pages existed but none had valid markdown/index - log.warning( - "OCR response contained pages, but none had valid content/index." - ) + log.warning('OCR response contained pages, but none had valid content/index.') return [ Document( - page_content="No valid text content found in document", + page_content='No valid text content found in document', metadata={ - "error": "no_valid_pages", - "total_pages": total_pages, - "file_name": self.file_name, + 'error': 'no_valid_pages', + 'total_pages': total_pages, + 'file_name': self.file_name, }, ) ] @@ -615,24 +575,20 @@ class MistralLoader: documents = self._process_results(ocr_response) total_time = time.time() - start_time - log.info( - f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" - ) + log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents') return documents except Exception as e: total_time = time.time() - start_time - log.error( - f"An error occurred during the loading process after {total_time:.2f}s: {e}" - ) + log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}') # Return an error document on failure return [ Document( - page_content=f"Error during processing: {e}", + page_content=f'Error during processing: {e}', metadata={ - "error": "processing_failed", - "file_name": self.file_name, + 'error': 'processing_failed', + 'file_name': self.file_name, }, ) ] @@ -643,9 +599,7 @@ class MistralLoader: self._delete_file(file_id) except Exception as del_e: # Log deletion error, but don't overwrite original error if one occurred - log.error( - f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" - ) + log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}') async def load_async(self) -> List[Document]: """ @@ -672,21 +626,19 @@ class MistralLoader: documents = self._process_results(ocr_response) total_time = time.time() - start_time - log.info( - f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" - ) + log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents') return documents except Exception as e: total_time = time.time() - start_time - log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") + log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}') return [ Document( - page_content=f"Error during OCR processing: {e}", + page_content=f'Error during OCR processing: {e}', metadata={ - "error": "processing_failed", - "file_name": self.file_name, + 'error': 'processing_failed', + 'file_name': self.file_name, }, ) ] @@ -697,11 +649,11 @@ class MistralLoader: async with self._get_session() as session: await self._delete_file_async(session, file_id) except Exception as cleanup_error: - log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") + log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}') @staticmethod async def load_multiple_async( - loaders: List["MistralLoader"], + loaders: List['MistralLoader'], max_concurrent: int = 5, # Limit concurrent requests ) -> List[List[Document]]: """ @@ -717,15 +669,13 @@ class MistralLoader: if not loaders: return [] - log.info( - f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent" - ) + log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent') start_time = time.time() # Use semaphore to control concurrency semaphore = asyncio.Semaphore(max_concurrent) - async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: + async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]: async with semaphore: return await loader.load_async() @@ -737,14 +687,14 @@ class MistralLoader: processed_results = [] for i, result in enumerate(results): if isinstance(result, Exception): - log.error(f"File {i} failed: {result}") + log.error(f'File {i} failed: {result}') processed_results.append( [ Document( - page_content=f"Error processing file: {result}", + page_content=f'Error processing file: {result}', metadata={ - "error": "batch_processing_failed", - "file_index": i, + 'error': 'batch_processing_failed', + 'file_index': i, }, ) ] @@ -755,15 +705,13 @@ class MistralLoader: # MONITORING: Log comprehensive batch processing statistics total_time = time.time() - start_time total_docs = sum(len(docs) for docs in processed_results) - success_count = sum( - 1 for result in results if not isinstance(result, Exception) - ) + success_count = sum(1 for result in results if not isinstance(result, Exception)) failure_count = len(results) - success_count log.info( - f"Batch processing completed in {total_time:.2f}s: " - f"{success_count} files succeeded, {failure_count} files failed, " - f"produced {total_docs} total documents" + f'Batch processing completed in {total_time:.2f}s: ' + f'{success_count} files succeeded, {failure_count} files failed, ' + f'produced {total_docs} total documents' ) return processed_results diff --git a/backend/open_webui/retrieval/loaders/tavily.py b/backend/open_webui/retrieval/loaders/tavily.py index f298de80b4..742ac499cf 100644 --- a/backend/open_webui/retrieval/loaders/tavily.py +++ b/backend/open_webui/retrieval/loaders/tavily.py @@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader): self, urls: Union[str, List[str]], api_key: str, - extract_depth: Literal["basic", "advanced"] = "basic", + extract_depth: Literal['basic', 'advanced'] = 'basic', continue_on_failure: bool = True, ) -> None: """Initialize Tavily Extract client. @@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader): continue_on_failure: Whether to continue if extraction of a URL fails. """ if not urls: - raise ValueError("At least one URL must be provided.") + raise ValueError('At least one URL must be provided.') self.api_key = api_key self.urls = urls if isinstance(urls, list) else [urls] self.extract_depth = extract_depth self.continue_on_failure = continue_on_failure - self.api_url = "https://api.tavily.com/extract" + self.api_url = 'https://api.tavily.com/extract' def lazy_load(self) -> Iterator[Document]: """Extract and yield documents from the URLs using Tavily Extract API.""" @@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader): batch_urls = self.urls[i : i + batch_size] try: headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', } # Use string for single URL, array for multiple URLs urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls - payload = {"urls": urls_param, "extract_depth": self.extract_depth} + payload = {'urls': urls_param, 'extract_depth': self.extract_depth} # Make the API call response = requests.post(self.api_url, headers=headers, json=payload) response.raise_for_status() response_data = response.json() # Process successful results - for result in response_data.get("results", []): - url = result.get("url", "") - content = result.get("raw_content", "") + for result in response_data.get('results', []): + url = result.get('url', '') + content = result.get('raw_content', '') if not content: - log.warning(f"No content extracted from {url}") + log.warning(f'No content extracted from {url}') continue # Add URLs as metadata - metadata = {"source": url} + metadata = {'source': url} yield Document( page_content=content, metadata=metadata, ) - for failed in response_data.get("failed_results", []): - url = failed.get("url", "") - error = failed.get("error", "Unknown error") - log.error(f"Failed to extract content from {url}: {error}") + for failed in response_data.get('failed_results', []): + url = failed.get('url', '') + error = failed.get('error', 'Unknown error') + log.error(f'Failed to extract content from {url}: {error}') except Exception as e: if self.continue_on_failure: - log.error(f"Error extracting content from batch {batch_urls}: {e}") + log.error(f'Error extracting content from batch {batch_urls}: {e}') else: raise e diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py index faf7b4452e..34a1d20740 100644 --- a/backend/open_webui/retrieval/loaders/youtube.py +++ b/backend/open_webui/retrieval/loaders/youtube.py @@ -7,14 +7,14 @@ from langchain_core.documents import Document log = logging.getLogger(__name__) -ALLOWED_SCHEMES = {"http", "https"} +ALLOWED_SCHEMES = {'http', 'https'} ALLOWED_NETLOCS = { - "youtu.be", - "m.youtube.com", - "youtube.com", - "www.youtube.com", - "www.youtube-nocookie.com", - "vid.plus", + 'youtu.be', + 'm.youtube.com', + 'youtube.com', + 'www.youtube.com', + 'www.youtube-nocookie.com', + 'vid.plus', } @@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]: path = parsed_url.path - if path.endswith("/watch"): + if path.endswith('/watch'): query = parsed_url.query parsed_query = parse_qs(query) - if "v" in parsed_query: - ids = parsed_query["v"] + if 'v' in parsed_query: + ids = parsed_query['v'] video_id = ids if isinstance(ids, str) else ids[0] else: return None else: - path = parsed_url.path.lstrip("/") - video_id = path.split("/")[-1] + path = parsed_url.path.lstrip('/') + video_id = path.split('/')[-1] if len(video_id) != 11: # Video IDs are 11 characters long return None @@ -54,13 +54,13 @@ class YoutubeLoader: def __init__( self, video_id: str, - language: Union[str, Sequence[str]] = "en", + language: Union[str, Sequence[str]] = 'en', proxy_url: Optional[str] = None, ): """Initialize with YouTube video ID.""" _video_id = _parse_video_id(video_id) self.video_id = _video_id if _video_id is not None else video_id - self._metadata = {"source": video_id} + self._metadata = {'source': video_id} self.proxy_url = proxy_url # Ensure language is a list @@ -70,8 +70,8 @@ class YoutubeLoader: self.language = list(language) # Add English as fallback if not already in the list - if "en" not in self.language: - self.language.append("en") + if 'en' not in self.language: + self.language.append('en') def load(self) -> List[Document]: """Load YouTube transcripts into `Document` objects.""" @@ -85,14 +85,12 @@ class YoutubeLoader: except ImportError: raise ImportError( 'Could not import "youtube_transcript_api" Python package. ' - "Please install it with `pip install youtube-transcript-api`." + 'Please install it with `pip install youtube-transcript-api`.' ) if self.proxy_url: - youtube_proxies = GenericProxyConfig( - http_url=self.proxy_url, https_url=self.proxy_url - ) - log.debug(f"Using proxy URL: {self.proxy_url[:14]}...") + youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url) + log.debug(f'Using proxy URL: {self.proxy_url[:14]}...') else: youtube_proxies = None @@ -100,7 +98,7 @@ class YoutubeLoader: try: transcript_list = transcript_api.list(self.video_id) except Exception as e: - log.exception("Loading YouTube transcript failed") + log.exception('Loading YouTube transcript failed') return [] # Try each language in order of priority @@ -110,14 +108,10 @@ class YoutubeLoader: if transcript.is_generated: log.debug(f"Found generated transcript for language '{lang}'") try: - transcript = transcript_list.find_manually_created_transcript( - [lang] - ) + transcript = transcript_list.find_manually_created_transcript([lang]) log.debug(f"Found manual transcript for language '{lang}'") except NoTranscriptFound: - log.debug( - f"No manual transcript found for language '{lang}', using generated" - ) + log.debug(f"No manual transcript found for language '{lang}', using generated") pass log.debug(f"Found transcript for language '{lang}'") @@ -131,12 +125,10 @@ class YoutubeLoader: log.debug(f"Empty transcript for language '{lang}'") continue - transcript_text = " ".join( + transcript_text = ' '.join( map( lambda transcript_piece: ( - transcript_piece.text.strip(" ") - if hasattr(transcript_piece, "text") - else "" + transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else '' ), transcript_pieces, ) @@ -150,9 +142,9 @@ class YoutubeLoader: raise e # If we get here, all languages failed - languages_tried = ", ".join(self.language) + languages_tried = ', '.join(self.language) log.warning( - f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed." + f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.' ) raise NoTranscriptFound(self.video_id, self.language, list(transcript_list)) diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index 2a8c0329d7..d122291ec5 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -13,19 +13,17 @@ log = logging.getLogger(__name__) class ColBERT(BaseReranker): def __init__(self, name, **kwargs) -> None: - log.info("ColBERT: Loading model", name) - self.device = "cuda" if torch.cuda.is_available() else "cpu" + log.info('ColBERT: Loading model', name) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - DOCKER = kwargs.get("env") == "docker" + DOCKER = kwargs.get('env') == 'docker' if DOCKER: # This is a workaround for the issue with the docker container # where the torch extension is not loaded properly # and the following error is thrown: # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory - lock_file = ( - "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" - ) + lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock' if os.path.exists(lock_file): os.remove(lock_file) @@ -36,23 +34,16 @@ class ColBERT(BaseReranker): pass def calculate_similarity_scores(self, query_embeddings, document_embeddings): - query_embeddings = query_embeddings.to(self.device) document_embeddings = document_embeddings.to(self.device) # Validate dimensions to ensure compatibility if query_embeddings.dim() != 3: - raise ValueError( - f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." - ) + raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.') if document_embeddings.dim() != 3: - raise ValueError( - f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." - ) + raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.') if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: - raise ValueError( - "There should be either one query or queries equal to the number of documents." - ) + raise ValueError('There should be either one query or queries equal to the number of documents.') # Transpose the query embeddings to align for matrix multiplication transposed_query_embeddings = query_embeddings.permute(0, 2, 1) @@ -69,7 +60,6 @@ class ColBERT(BaseReranker): return normalized_scores.detach().cpu().numpy().astype(np.float32) def predict(self, sentences): - query = sentences[0][0] docs = [i[1] for i in sentences] @@ -80,8 +70,6 @@ class ColBERT(BaseReranker): embedded_query = embedded_queries[0] # Calculate retrieval scores for the query against all documents - scores = self.calculate_similarity_scores( - embedded_query.unsqueeze(0), embedded_docs - ) + scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs) return scores diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index cd24dc6af2..f04583b965 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker): def __init__( self, api_key: str, - url: str = "http://localhost:8080/v1/rerank", - model: str = "reranker", + url: str = 'http://localhost:8080/v1/rerank', + model: str = 'reranker', timeout: Optional[int] = None, ): self.api_key = api_key @@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker): self.model = model self.timeout = timeout - def predict( - self, sentences: List[Tuple[str, str]], user=None - ) -> Optional[List[float]]: + def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]: query = sentences[0][0] docs = [i[1] for i in sentences] payload = { - "model": self.model, - "query": query, - "documents": docs, - "top_n": len(docs), + 'model': self.model, + 'query': query, + 'documents': docs, + 'top_n': len(docs), } try: - log.info(f"ExternalReranker:predict:model {self.model}") - log.info(f"ExternalReranker:predict:query {query}") + log.info(f'ExternalReranker:predict:model {self.model}') + log.info(f'ExternalReranker:predict:query {query}') headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.post( - f"{self.url}", + f'{self.url}', headers=headers, json=payload, timeout=self.timeout, @@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker): r.raise_for_status() data = r.json() - if "results" in data: - sorted_results = sorted(data["results"], key=lambda x: x["index"]) - return [result["relevance_score"] for result in sorted_results] + if 'results' in data: + sorted_results = sorted(data['results'], key=lambda x: x['index']) + return [result['relevance_score'] for result in sorted_results] else: - log.error("No results found in external reranking response") + log.error('No results found in external reranking response') return None except Exception as e: - log.exception(f"Error in external reranking: {e}") + log.exception(f'Error in external reranking: {e}') return None diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 5030a6eecb..dfb38a9659 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -61,7 +61,7 @@ from langchain_core.retrievers import BaseRetriever def is_youtube_url(url: str) -> bool: - youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$" + youtube_regex = r'^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$' return re.match(youtube_regex, url) is not None @@ -84,11 +84,11 @@ def get_loader(request, url: str): def get_content_from_url(request, url: str) -> str: loader = get_loader(request, url) docs = loader.load() - content = " ".join([doc.page_content for doc in docs]) + content = ' '.join([doc.page_content for doc in docs]) return content, docs -CHUNK_HASH_KEY = "_chunk_hash" +CHUNK_HASH_KEY = '_chunk_hash' def _content_hash(text: str) -> str: @@ -101,9 +101,7 @@ class VectorSearchRetriever(BaseRetriever): embedding_function: Any top_k: int - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> list[Document]: + def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> list[Document]: """Get documents relevant to a query. Args: @@ -145,11 +143,9 @@ class VectorSearchRetriever(BaseRetriever): return results -def query_doc( - collection_name: str, query_embedding: list[float], k: int, user: UserModel = None -): +def query_doc(collection_name: str, query_embedding: list[float], k: int, user: UserModel = None): try: - log.debug(f"query_doc:doc {collection_name}") + log.debug(f'query_doc:doc {collection_name}') result = VECTOR_DB_CLIENT.search( collection_name=collection_name, vectors=[query_embedding], @@ -157,25 +153,25 @@ def query_doc( ) if result: - log.info(f"query_doc:result {result.ids} {result.metadatas}") + log.info(f'query_doc:result {result.ids} {result.metadatas}') return result except Exception as e: - log.exception(f"Error querying doc {collection_name} with limit {k}: {e}") + log.exception(f'Error querying doc {collection_name} with limit {k}: {e}') raise e def get_doc(collection_name: str, user: UserModel = None): try: - log.debug(f"get_doc:doc {collection_name}") + log.debug(f'get_doc:doc {collection_name}') result = VECTOR_DB_CLIENT.get(collection_name=collection_name) if result: - log.info(f"query_doc:result {result.ids} {result.metadatas}") + log.info(f'query_doc:result {result.ids} {result.metadatas}') return result except Exception as e: - log.exception(f"Error getting doc {collection_name}: {e}") + log.exception(f'Error getting doc {collection_name}: {e}') raise e @@ -186,33 +182,29 @@ def get_enriched_texts(collection_result: GetResult) -> list[str]: metadata_parts = [text] # Add filename (repeat twice for extra weight in BM25 scoring) - if metadata.get("name"): - filename = metadata["name"] - filename_tokens = ( - filename.replace("_", " ").replace("-", " ").replace(".", " ") - ) - metadata_parts.append( - f"Filename: {filename} {filename_tokens} {filename_tokens}" - ) + if metadata.get('name'): + filename = metadata['name'] + filename_tokens = filename.replace('_', ' ').replace('-', ' ').replace('.', ' ') + metadata_parts.append(f'Filename: {filename} {filename_tokens} {filename_tokens}') # Add title if available - if metadata.get("title"): - metadata_parts.append(f"Title: {metadata['title']}") + if metadata.get('title'): + metadata_parts.append(f'Title: {metadata["title"]}') # Add document section headings if available (from markdown splitter) - if metadata.get("headings") and isinstance(metadata["headings"], list): - headings = " > ".join(str(h) for h in metadata["headings"]) - metadata_parts.append(f"Section: {headings}") + if metadata.get('headings') and isinstance(metadata['headings'], list): + headings = ' > '.join(str(h) for h in metadata['headings']) + metadata_parts.append(f'Section: {headings}') # Add source URL/path if available - if metadata.get("source"): - metadata_parts.append(f"Source: {metadata['source']}") + if metadata.get('source'): + metadata_parts.append(f'Source: {metadata["source"]}') # Add snippet for web search results - if metadata.get("snippet"): - metadata_parts.append(f"Snippet: {metadata['snippet']}") + if metadata.get('snippet'): + metadata_parts.append(f'Snippet: {metadata["snippet"]}') - enriched_texts.append(" ".join(metadata_parts)) + enriched_texts.append(' '.join(metadata_parts)) return enriched_texts @@ -233,11 +225,11 @@ async def query_doc_with_hybrid_search( # First check if collection_result has the required attributes if ( not collection_result - or not hasattr(collection_result, "documents") - or not hasattr(collection_result, "metadatas") + or not hasattr(collection_result, 'documents') + or not hasattr(collection_result, 'metadatas') ): - log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}") - return {"documents": [], "metadatas": [], "distances": []} + log.warning(f'query_doc_with_hybrid_search:no_docs {collection_name}') + return {'documents': [], 'metadatas': [], 'distances': []} # Now safely check the documents content after confirming attributes exist if ( @@ -245,10 +237,10 @@ async def query_doc_with_hybrid_search( or len(collection_result.documents) == 0 or not collection_result.documents[0] ): - log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}") - return {"documents": [], "metadatas": [], "distances": []} + log.warning(f'query_doc_with_hybrid_search:no_docs {collection_name}') + return {'documents': [], 'metadatas': [], 'distances': []} - log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") + log.debug(f'query_doc_with_hybrid_search:doc {collection_name}') original_texts = collection_result.documents[0] bm25_metadatas = [ @@ -256,11 +248,7 @@ async def query_doc_with_hybrid_search( for idx, meta in enumerate(collection_result.metadatas[0]) ] - bm25_texts = ( - get_enriched_texts(collection_result) - if enable_enriched_texts - else original_texts - ) + bm25_texts = get_enriched_texts(collection_result) if enable_enriched_texts else original_texts bm25_retriever = BM25Retriever.from_texts( texts=bm25_texts, @@ -307,15 +295,13 @@ async def query_doc_with_hybrid_search( result = await compression_retriever.ainvoke(query) - distances = [d.metadata.get("score") for d in result] + distances = [d.metadata.get('score') for d in result] documents = [d.page_content for d in result] metadatas = [d.metadata for d in result] # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker if k < k_reranker: - sorted_items = sorted( - zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True - ) + sorted_items = sorted(zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True) sorted_items = sorted_items[:k] if sorted_items: @@ -324,18 +310,15 @@ async def query_doc_with_hybrid_search( distances, documents, metadatas = [], [], [] result = { - "distances": [distances], - "documents": [documents], - "metadatas": [metadatas], + 'distances': [distances], + 'documents': [documents], + 'metadatas': [metadatas], } - log.info( - "query_doc_with_hybrid_search:result " - + f'{result["metadatas"]} {result["distances"]}' - ) + log.info('query_doc_with_hybrid_search:result ' + f'{result["metadatas"]} {result["distances"]}') return result except Exception as e: - log.exception(f"Error querying doc {collection_name} with hybrid search: {e}") + log.exception(f'Error querying doc {collection_name} with hybrid search: {e}') raise e @@ -346,15 +329,15 @@ def merge_get_results(get_results: list[dict]) -> dict: combined_ids = [] for data in get_results: - combined_documents.extend(data["documents"][0]) - combined_metadatas.extend(data["metadatas"][0]) - combined_ids.extend(data["ids"][0]) + combined_documents.extend(data['documents'][0]) + combined_metadatas.extend(data['metadatas'][0]) + combined_ids.extend(data['ids'][0]) # Create the output dictionary result = { - "documents": [combined_documents], - "metadatas": [combined_metadatas], - "ids": [combined_ids], + 'documents': [combined_documents], + 'metadatas': [combined_metadatas], + 'ids': [combined_ids], } return result @@ -366,21 +349,19 @@ def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict: for data in query_results: if ( - len(data.get("distances", [])) == 0 - or len(data.get("documents", [])) == 0 - or len(data.get("metadatas", [])) == 0 + len(data.get('distances', [])) == 0 + or len(data.get('documents', [])) == 0 + or len(data.get('metadatas', [])) == 0 ): continue - distances = data["distances"][0] - documents = data["documents"][0] - metadatas = data["metadatas"][0] + distances = data['distances'][0] + documents = data['documents'][0] + metadatas = data['metadatas'][0] for distance, document, metadata in zip(distances, documents, metadatas): if isinstance(document, str): - doc_hash = hashlib.sha256( - document.encode() - ).hexdigest() # Compute a hash for uniqueness + doc_hash = hashlib.sha256(document.encode()).hexdigest() # Compute a hash for uniqueness if doc_hash not in combined.keys(): combined[doc_hash] = (distance, document, metadata) @@ -395,15 +376,13 @@ def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict: combined.sort(key=lambda x: x[0], reverse=True) # Slice to keep only the top k elements - sorted_distances, sorted_documents, sorted_metadatas = ( - zip(*combined[:k]) if combined else ([], [], []) - ) + sorted_distances, sorted_documents, sorted_metadatas = zip(*combined[:k]) if combined else ([], [], []) # Create and return the output dictionary return { - "distances": [list(sorted_distances)], - "documents": [list(sorted_documents)], - "metadatas": [list(sorted_metadatas)], + 'distances': [list(sorted_distances)], + 'documents': [list(sorted_documents)], + 'metadatas': [list(sorted_metadatas)], } @@ -417,7 +396,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict: if result is not None: results.append(result.model_dump()) except Exception as e: - log.exception(f"Error when querying the collection: {e}") + log.exception(f'Error when querying the collection: {e}') else: pass @@ -445,24 +424,18 @@ async def query_collection( return result.model_dump(), None return None, None except Exception as e: - log.exception(f"Error when querying the collection: {e}") + log.exception(f'Error when querying the collection: {e}') return None, e # Generate all query embeddings (in one call) - query_embeddings = await embedding_function( - queries, prefix=RAG_EMBEDDING_QUERY_PREFIX - ) - log.debug( - f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" - ) + query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) + log.debug(f'query_collection: processing {len(queries)} queries across {len(collection_names)} collections') with ThreadPoolExecutor() as executor: future_results = [] for query_embedding in query_embeddings: for collection_name in collection_names: - result = executor.submit( - process_query_collection, collection_name, query_embedding - ) + result = executor.submit(process_query_collection, collection_name, query_embedding) future_results.append(result) task_results = [future.result() for future in future_results] @@ -473,7 +446,7 @@ async def query_collection( results.append(result) if error and not results: - log.warning("All collection queries failed. No results returned.") + log.warning('All collection queries failed. No results returned.') return merge_and_sort_query_results(results, k=k) @@ -496,19 +469,13 @@ async def query_collection_with_hybrid_search( collection_results = {} for collection_name in collection_names: try: - log.debug( - f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}" - ) - collection_results[collection_name] = VECTOR_DB_CLIENT.get( - collection_name=collection_name - ) + log.debug(f'query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}') + collection_results[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name) except Exception as e: - log.exception(f"Failed to fetch collection {collection_name}: {e}") + log.exception(f'Failed to fetch collection {collection_name}: {e}') collection_results[collection_name] = None - log.info( - f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..." - ) + log.info(f'Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections...') async def process_query(collection_name, query): try: @@ -526,7 +493,7 @@ async def query_collection_with_hybrid_search( ) return result, None except Exception as e: - log.exception(f"Error when querying the collection with hybrid_search: {e}") + log.exception(f'Error when querying the collection with hybrid_search: {e}') return None, e # Prepare tasks for all collections and queries @@ -539,9 +506,7 @@ async def query_collection_with_hybrid_search( ] # Run all queries in parallel using asyncio.gather - task_results = await asyncio.gather( - *[process_query(collection_name, query) for collection_name, query in tasks] - ) + task_results = await asyncio.gather(*[process_query(collection_name, query) for collection_name, query in tasks]) for result, err in task_results: if err is not None: @@ -550,9 +515,7 @@ async def query_collection_with_hybrid_search( results.append(result) if error and not results: - raise Exception( - "Hybrid search failed for all collections. Using Non-hybrid search as fallback." - ) + raise Exception('Hybrid search failed for all collections. Using Non-hybrid search as fallback.') return merge_and_sort_query_results(results, k=k) @@ -560,63 +523,57 @@ async def query_collection_with_hybrid_search( def generate_openai_batch_embeddings( model: str, texts: list[str], - url: str = "https://api.openai.com/v1", - key: str = "", + url: str = 'https://api.openai.com/v1', + key: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}" - ) - json_data = {"input": texts, "model": model} + log.debug(f'generate_openai_batch_embeddings:model {model} batch size: {len(texts)}') + json_data = {'input': texts, 'model': model} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {key}', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.post( - f"{url}/embeddings", + f'{url}/embeddings', headers=headers, json=json_data, ) r.raise_for_status() data = r.json() - if "data" in data: - return [elem["embedding"] for elem in data["data"]] + if 'data' in data: + return [elem['embedding'] for elem in data['data']] else: - raise ValueError( - "Unexpected OpenAI embeddings response: missing 'data' key" - ) + raise ValueError("Unexpected OpenAI embeddings response: missing 'data' key") except Exception as e: - log.exception(f"Error generating openai batch embeddings: {e}") + log.exception(f'Error generating openai batch embeddings: {e}') return None async def agenerate_openai_batch_embeddings( model: str, texts: list[str], - url: str = "https://api.openai.com/v1", - key: str = "", + url: str = 'https://api.openai.com/v1', + key: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}" - ) - form_data = {"input": texts, "model": model} + log.debug(f'agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}') + form_data = {'input': texts, 'model': model} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {key}', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) @@ -625,19 +582,19 @@ async def agenerate_openai_batch_embeddings( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: async with session.post( - f"{url}/embeddings", + f'{url}/embeddings', headers=headers, json=form_data, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() data = await r.json() - if "data" in data: - return [item["embedding"] for item in data["data"]] + if 'data' in data: + return [item['embedding'] for item in data['data']] else: - raise Exception("Something went wrong :/") + raise Exception('Something went wrong :/') except Exception as e: - log.exception(f"Error generating openai batch embeddings: {e}") + log.exception(f'Error generating openai batch embeddings: {e}') return None @@ -645,25 +602,23 @@ def generate_azure_openai_batch_embeddings( model: str, texts: list[str], url: str, - key: str = "", - version: str = "", + key: str = '', + version: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" - ) - json_data = {"input": texts} + log.debug(f'generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}') + json_data = {'input': texts} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix - url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}" + url = f'{url}/openai/deployments/{model}/embeddings?api-version={version}' for _ in range(5): headers = { - "Content-Type": "application/json", - "api-key": key, + 'Content-Type': 'application/json', + 'api-key': key, } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) @@ -674,18 +629,18 @@ def generate_azure_openai_batch_embeddings( json=json_data, ) if r.status_code == 429: - retry = float(r.headers.get("Retry-After", "1")) + retry = float(r.headers.get('Retry-After', '1')) time.sleep(retry) continue r.raise_for_status() data = r.json() - if "data" in data: - return [elem["embedding"] for elem in data["data"]] + if 'data' in data: + return [elem['embedding'] for elem in data['data']] else: - raise Exception("Something went wrong :/") + raise Exception('Something went wrong :/') return None except Exception as e: - log.exception(f"Error generating azure openai batch embeddings: {e}") + log.exception(f'Error generating azure openai batch embeddings: {e}') return None @@ -693,24 +648,22 @@ async def agenerate_azure_openai_batch_embeddings( model: str, texts: list[str], url: str, - key: str = "", - version: str = "", + key: str = '', + version: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" - ) - form_data = {"input": texts} + log.debug(f'agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}') + form_data = {'input': texts} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix - full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}" + full_url = f'{url}/openai/deployments/{model}/embeddings?api-version={version}' headers = { - "Content-Type": "application/json", - "api-key": key, + 'Content-Type': 'application/json', + 'api-key': key, } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) @@ -726,12 +679,12 @@ async def agenerate_azure_openai_batch_embeddings( ) as r: r.raise_for_status() data = await r.json() - if "data" in data: - return [item["embedding"] for item in data["data"]] + if 'data' in data: + return [item['embedding'] for item in data['data']] else: - raise Exception("Something went wrong :/") + raise Exception('Something went wrong :/') except Exception as e: - log.exception(f"Error generating azure openai batch embeddings: {e}") + log.exception(f'Error generating azure openai batch embeddings: {e}') return None @@ -739,41 +692,37 @@ def generate_ollama_batch_embeddings( model: str, texts: list[str], url: str, - key: str = "", + key: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}" - ) - json_data = {"input": texts, "model": model} + log.debug(f'generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}') + json_data = {'input': texts, 'model': model} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {key}', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.post( - f"{url}/api/embed", + f'{url}/api/embed', headers=headers, json=json_data, ) r.raise_for_status() data = r.json() - if "embeddings" in data: - return data["embeddings"] + if 'embeddings' in data: + return data['embeddings'] else: - raise ValueError( - "Unexpected Ollama embeddings response: missing 'embeddings' key" - ) + raise ValueError("Unexpected Ollama embeddings response: missing 'embeddings' key") except Exception as e: - log.exception(f"Error generating ollama batch embeddings: {e}") + log.exception(f'Error generating ollama batch embeddings: {e}') return None @@ -781,21 +730,19 @@ async def agenerate_ollama_batch_embeddings( model: str, texts: list[str], url: str, - key: str = "", + key: str = '', prefix: str = None, user: UserModel = None, ) -> Optional[list[list[float]]]: try: - log.debug( - f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}" - ) - form_data = {"input": texts, "model": model} + log.debug(f'agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}') + form_data = {'input': texts, 'model': model} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {key}', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) @@ -804,19 +751,19 @@ async def agenerate_ollama_batch_embeddings( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: async with session.post( - f"{url}/api/embed", + f'{url}/api/embed', headers=headers, json=form_data, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() data = await r.json() - if "embeddings" in data: - return data["embeddings"] + if 'embeddings' in data: + return data['embeddings'] else: - raise Exception("Something went wrong :/") + raise Exception('Something went wrong :/') except Exception as e: - log.exception(f"Error generating ollama batch embeddings: {e}") + log.exception(f'Error generating ollama batch embeddings: {e}') return None @@ -831,7 +778,7 @@ def get_embedding_function( enable_async=True, concurrent_requests=0, ) -> Awaitable: - if embedding_engine == "": + if embedding_engine == '': # Sentence transformers: CPU-bound sync operation async def async_embedding_function(query, prefix=None, user=None): return await asyncio.to_thread( @@ -839,7 +786,7 @@ def get_embedding_function( lambda query, prefix=None: embedding_function.encode( query, batch_size=int(embedding_batch_size), - **({"prompt": prefix} if prefix else {}), + **({'prompt': prefix} if prefix else {}), ).tolist() ), query, @@ -847,7 +794,7 @@ def get_embedding_function( ) return async_embedding_function - elif embedding_engine in ["ollama", "openai", "azure_openai"]: + elif embedding_engine in ['ollama', 'openai', 'azure_openai']: embedding_function = lambda query, prefix=None, user=None: generate_embeddings( engine=embedding_engine, model=embedding_model, @@ -862,15 +809,10 @@ def get_embedding_function( async def async_embedding_function(query, prefix=None, user=None): if isinstance(query, list): # Create batches - batches = [ - query[i : i + embedding_batch_size] - for i in range(0, len(query), embedding_batch_size) - ] + batches = [query[i : i + embedding_batch_size] for i in range(0, len(query), embedding_batch_size)] if enable_async: - log.debug( - f"generate_multiple_async: Processing {len(batches)} batches in parallel" - ) + log.debug(f'generate_multiple_async: Processing {len(batches)} batches in parallel') # Use semaphore to limit concurrent embedding API requests # 0 = unlimited (no semaphore) if concurrent_requests: @@ -878,28 +820,17 @@ def get_embedding_function( async def generate_batch_with_semaphore(batch): async with semaphore: - return await embedding_function( - batch, prefix=prefix, user=user - ) + return await embedding_function(batch, prefix=prefix, user=user) - tasks = [ - generate_batch_with_semaphore(batch) for batch in batches - ] + tasks = [generate_batch_with_semaphore(batch) for batch in batches] else: - tasks = [ - embedding_function(batch, prefix=prefix, user=user) - for batch in batches - ] + tasks = [embedding_function(batch, prefix=prefix, user=user) for batch in batches] batch_results = await asyncio.gather(*tasks) else: - log.debug( - f"generate_multiple_async: Processing {len(batches)} batches sequentially" - ) + log.debug(f'generate_multiple_async: Processing {len(batches)} batches sequentially') batch_results = [] for batch in batches: - batch_results.append( - await embedding_function(batch, prefix=prefix, user=user) - ) + batch_results.append(await embedding_function(batch, prefix=prefix, user=user)) # Flatten results embeddings = [] @@ -908,7 +839,7 @@ def get_embedding_function( embeddings.extend(batch_embeddings) log.debug( - f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches" + f'generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches' ) return embeddings else: @@ -916,7 +847,7 @@ def get_embedding_function( return async_embedding_function else: - raise ValueError(f"Unknown embedding engine: {embedding_engine}") + raise ValueError(f'Unknown embedding engine: {embedding_engine}') async def generate_embeddings( @@ -926,35 +857,35 @@ async def generate_embeddings( prefix: Union[str, None] = None, **kwargs, ): - url = kwargs.get("url", "") - key = kwargs.get("key", "") - user = kwargs.get("user") + url = kwargs.get('url', '') + key = kwargs.get('key', '') + user = kwargs.get('user') if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None: if isinstance(text, list): - text = [f"{prefix}{text_element}" for text_element in text] + text = [f'{prefix}{text_element}' for text_element in text] else: - text = f"{prefix}{text}" + text = f'{prefix}{text}' - if engine == "ollama": + if engine == 'ollama': embeddings = await agenerate_ollama_batch_embeddings( **{ - "model": model, - "texts": text if isinstance(text, list) else [text], - "url": url, - "key": key, - "prefix": prefix, - "user": user, + 'model': model, + 'texts': text if isinstance(text, list) else [text], + 'url': url, + 'key': key, + 'prefix': prefix, + 'user': user, } ) return embeddings[0] if isinstance(text, str) else embeddings - elif engine == "openai": + elif engine == 'openai': embeddings = await agenerate_openai_batch_embeddings( model, text if isinstance(text, list) else [text], url, key, prefix, user ) return embeddings[0] if isinstance(text, str) else embeddings - elif engine == "azure_openai": - azure_api_version = kwargs.get("azure_api_version", "") + elif engine == 'azure_openai': + azure_api_version = kwargs.get('azure_api_version', '') embeddings = await agenerate_azure_openai_batch_embeddings( model, text if isinstance(text, list) else [text], @@ -970,7 +901,7 @@ async def generate_embeddings( def get_reranking_function(reranking_engine, reranking_model, reranking_function): if reranking_function is None: return None - if reranking_engine == "external": + if reranking_engine == 'external': return lambda query, documents, user=None: reranking_function.predict( [(query, doc.page_content) for doc in documents], user=user ) @@ -994,9 +925,7 @@ async def get_sources_from_items( full_context=False, user: Optional[UserModel] = None, ): - log.debug( - f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}" - ) + log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}') extracted_collections = [] query_results = [] @@ -1005,165 +934,146 @@ async def get_sources_from_items( query_result = None collection_names = [] - if item.get("type") == "text": + if item.get('type') == 'text': # Raw Text # Used during temporary chat file uploads or web page & youtube attachements - if item.get("context") == "full": - if item.get("file"): + if item.get('context') == 'full': + if item.get('file'): # if item has file data, use it query_result = { - "documents": [ - [item.get("file", {}).get("data", {}).get("content")] - ], - "metadatas": [[item.get("file", {}).get("meta", {})]], + 'documents': [[item.get('file', {}).get('data', {}).get('content')]], + 'metadatas': [[item.get('file', {}).get('meta', {})]], } if query_result is None: # Fallback - if item.get("collection_name"): + if item.get('collection_name'): # If item has a collection name, use it - collection_names.append(item.get("collection_name")) - elif item.get("file"): + collection_names.append(item.get('collection_name')) + elif item.get('file'): # If item has file data, use it query_result = { - "documents": [ - [item.get("file", {}).get("data", {}).get("content")] - ], - "metadatas": [[item.get("file", {}).get("meta", {})]], + 'documents': [[item.get('file', {}).get('data', {}).get('content')]], + 'metadatas': [[item.get('file', {}).get('meta', {})]], } else: # Fallback to item content query_result = { - "documents": [[item.get("content")]], - "metadatas": [ - [{"file_id": item.get("id"), "name": item.get("name")}] - ], + 'documents': [[item.get('content')]], + 'metadatas': [[{'file_id': item.get('id'), 'name': item.get('name')}]], } - elif item.get("type") == "note": + elif item.get('type') == 'note': # Note Attached - note = Notes.get_note_by_id(item.get("id")) + note = Notes.get_note_by_id(item.get('id')) if note and ( - user.role == "admin" + user.role == 'admin' or note.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="read", + permission='read', ) ): # User has access to the note query_result = { - "documents": [[note.data.get("content", {}).get("md", "")]], - "metadatas": [[{"file_id": note.id, "name": note.title}]], + 'documents': [[note.data.get('content', {}).get('md', '')]], + 'metadatas': [[{'file_id': note.id, 'name': note.title}]], } - elif item.get("type") == "chat": + elif item.get('type') == 'chat': # Chat Attached - chat = Chats.get_chat_by_id(item.get("id")) + chat = Chats.get_chat_by_id(item.get('id')) - if chat and (user.role == "admin" or chat.user_id == user.id): - messages_map = chat.chat.get("history", {}).get("messages", {}) - message_id = chat.chat.get("history", {}).get("currentId") + if chat and (user.role == 'admin' or chat.user_id == user.id): + messages_map = chat.chat.get('history', {}).get('messages', {}) + message_id = chat.chat.get('history', {}).get('currentId') if messages_map and message_id: # Reconstruct the message list in order message_list = get_message_list(messages_map, message_id) - message_history = "\n".join( - [ - f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n" - for m in message_list - ] + message_history = '\n'.join( + [f'#### {m.get("role", "user").capitalize()}\n{m.get("content")}\n' for m in message_list] ) # User has access to the chat query_result = { - "documents": [[message_history]], - "metadatas": [[{"file_id": chat.id, "name": chat.title}]], + 'documents': [[message_history]], + 'metadatas': [[{'file_id': chat.id, 'name': chat.title}]], } - elif item.get("type") == "url": - content, docs = get_content_from_url(request, item.get("url")) + elif item.get('type') == 'url': + content, docs = get_content_from_url(request, item.get('url')) if docs: query_result = { - "documents": [[content]], - "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]], + 'documents': [[content]], + 'metadatas': [[{'url': item.get('url'), 'name': item.get('url')}]], } - elif item.get("type") == "file": - if ( - item.get("context") == "full" - or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ): - if item.get("file", {}).get("data", {}).get("content", ""): + elif item.get('type') == 'file': + if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + if item.get('file', {}).get('data', {}).get('content', ''): # Manual Full Mode Toggle # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content") query_result = { - "documents": [ - [item.get("file", {}).get("data", {}).get("content", "")] - ], - "metadatas": [ + 'documents': [[item.get('file', {}).get('data', {}).get('content', '')]], + 'metadatas': [ [ { - "file_id": item.get("id"), - "name": item.get("name"), - **item.get("file") - .get("data", {}) - .get("metadata", {}), + 'file_id': item.get('id'), + 'name': item.get('name'), + **item.get('file').get('data', {}).get('metadata', {}), } ] ], } - elif item.get("id"): - file_object = Files.get_file_by_id(item.get("id")) + elif item.get('id'): + file_object = Files.get_file_by_id(item.get('id')) if file_object: query_result = { - "documents": [[file_object.data.get("content", "")]], - "metadatas": [ + 'documents': [[file_object.data.get('content', '')]], + 'metadatas': [ [ { - "file_id": item.get("id"), - "name": file_object.filename, - "source": file_object.filename, + 'file_id': item.get('id'), + 'name': file_object.filename, + 'source': file_object.filename, } ] ], } else: # Fallback to collection names - if item.get("legacy"): - collection_names.append(f"{item['id']}") + if item.get('legacy'): + collection_names.append(f'{item["id"]}') else: - collection_names.append(f"file-{item['id']}") + collection_names.append(f'file-{item["id"]}') - elif item.get("type") == "collection": + elif item.get('type') == 'collection': # Manual Full Mode Toggle for Collection - knowledge_base = Knowledges.get_knowledge_by_id(item.get("id")) + knowledge_base = Knowledges.get_knowledge_by_id(item.get('id')) if knowledge_base and ( - user.role == "admin" + user.role == 'admin' or knowledge_base.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, - permission="read", + permission='read', ) ): - if ( - item.get("context") == "full" - or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ): + if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: if knowledge_base and ( - user.role == "admin" + user.role == 'admin' or knowledge_base.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, - permission="read", + permission='read', ) ): files = Knowledges.get_files_by_id(knowledge_base.id) @@ -1171,45 +1081,45 @@ async def get_sources_from_items( documents = [] metadatas = [] for file in files: - documents.append(file.data.get("content", "")) + documents.append(file.data.get('content', '')) metadatas.append( { - "file_id": file.id, - "name": file.filename, - "source": file.filename, + 'file_id': file.id, + 'name': file.filename, + 'source': file.filename, } ) query_result = { - "documents": [documents], - "metadatas": [metadatas], + 'documents': [documents], + 'metadatas': [metadatas], } else: # Fallback to collection names - if item.get("legacy"): - collection_names = item.get("collection_names", []) + if item.get('legacy'): + collection_names = item.get('collection_names', []) else: - collection_names.append(item["id"]) + collection_names.append(item['id']) - elif item.get("docs"): + elif item.get('docs'): # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL query_result = { - "documents": [[doc.get("content") for doc in item.get("docs")]], - "metadatas": [[doc.get("metadata") for doc in item.get("docs")]], + 'documents': [[doc.get('content') for doc in item.get('docs')]], + 'metadatas': [[doc.get('metadata') for doc in item.get('docs')]], } - elif item.get("collection_name"): + elif item.get('collection_name'): # Direct Collection Name - collection_names.append(item["collection_name"]) - elif item.get("collection_names"): + collection_names.append(item['collection_name']) + elif item.get('collection_names'): # Collection Names List - collection_names.extend(item["collection_names"]) + collection_names.extend(item['collection_names']) # If query_result is None # Fallback to collection names and vector search the collections if query_result is None and collection_names: collection_names = set(collection_names).difference(extracted_collections) if not collection_names: - log.debug(f"skipping {item} as it has already been extracted") + log.debug(f'skipping {item} as it has already been extracted') continue try: @@ -1231,9 +1141,7 @@ async def get_sources_from_items( enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS, ) except Exception as e: - log.debug( - "Error when using hybrid search, using non hybrid search as fallback." - ) + log.debug('Error when using hybrid search, using non hybrid search as fallback.') # fallback to non-hybrid search if not hybrid_search and query_result is None: @@ -1249,22 +1157,22 @@ async def get_sources_from_items( extracted_collections.extend(collection_names) if query_result: - if "data" in item: - del item["data"] - query_results.append({**query_result, "file": item}) + if 'data' in item: + del item['data'] + query_results.append({**query_result, 'file': item}) sources = [] for query_result in query_results: try: - if "documents" in query_result: - if "metadatas" in query_result: + if 'documents' in query_result: + if 'metadatas' in query_result: source = { - "source": query_result["file"], - "document": query_result["documents"][0], - "metadata": query_result["metadatas"][0], + 'source': query_result['file'], + 'document': query_result['documents'][0], + 'metadata': query_result['metadatas'][0], } - if "distances" in query_result and query_result["distances"]: - source["distances"] = query_result["distances"][0] + if 'distances' in query_result and query_result['distances']: + source['distances'] = query_result['distances'][0] sources.append(source) except Exception as e: @@ -1274,7 +1182,7 @@ async def get_sources_from_items( def get_model_path(model: str, update_model: bool = False): # Construct huggingface_hub kwargs with local_files_only to return the snapshot path - cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") + cache_dir = os.getenv('SENTENCE_TRANSFORMERS_HOME') local_files_only = not update_model @@ -1282,34 +1190,30 @@ def get_model_path(model: str, update_model: bool = False): local_files_only = True snapshot_kwargs = { - "cache_dir": cache_dir, - "local_files_only": local_files_only, + 'cache_dir': cache_dir, + 'local_files_only': local_files_only, } - log.debug(f"model: {model}") - log.debug(f"snapshot_kwargs: {snapshot_kwargs}") + log.debug(f'model: {model}') + log.debug(f'snapshot_kwargs: {snapshot_kwargs}') # Inspiration from upstream sentence_transformers - if ( - os.path.exists(model) - or ("\\" in model or model.count("/") > 1) - and local_files_only - ): + if os.path.exists(model) or ('\\' in model or model.count('/') > 1) and local_files_only: # If fully qualified path exists, return input, else set repo_id return model - elif "/" not in model: + elif '/' not in model: # Set valid repo_id for model short-name - model = "sentence-transformers" + "/" + model + model = 'sentence-transformers' + '/' + model - snapshot_kwargs["repo_id"] = model + snapshot_kwargs['repo_id'] = model # Attempt to query the huggingface_hub library to determine the local path and/or to update try: model_repo_path = snapshot_download(**snapshot_kwargs) - log.debug(f"model_repo_path: {model_repo_path}") + log.debug(f'model_repo_path: {model_repo_path}') return model_repo_path except Exception as e: - log.exception(f"Cannot determine model snapshot path: {e}") + log.exception(f'Cannot determine model snapshot path: {e}') if OFFLINE_MODE: raise return model @@ -1329,7 +1233,7 @@ class RerankCompressor(BaseDocumentCompressor): r_score: float class Config: - extra = "forbid" + extra = 'forbid' arbitrary_types_allowed = True def compress_documents( @@ -1365,9 +1269,7 @@ class RerankCompressor(BaseDocumentCompressor): else: from sentence_transformers import util - query_embedding = await self.embedding_function( - query, RAG_EMBEDDING_QUERY_PREFIX - ) + query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) document_embedding = await self.embedding_function( [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX ) @@ -1381,15 +1283,13 @@ class RerankCompressor(BaseDocumentCompressor): ) ) if self.r_score: - docs_with_scores = [ - (d, s) for d, s in docs_with_scores if s >= self.r_score - ] + docs_with_scores = [(d, s) for d, s in docs_with_scores if s >= self.r_score] result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) final_results = [] for doc, doc_score in result[: self.top_n]: metadata = doc.metadata - metadata["score"] = doc_score + metadata['score'] = doc_score doc = Document( page_content=doc.page_content, metadata=metadata, @@ -1397,7 +1297,5 @@ class RerankCompressor(BaseDocumentCompressor): final_results.append(doc) return final_results else: - log.warning( - "No valid scores found, check your reranking function. Returning original documents." - ) + log.warning('No valid scores found, check your reranking function. Returning original documents.') return documents diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index b8dc88945b..4ace732b2d 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -31,17 +31,15 @@ log = logging.getLogger(__name__) class ChromaClient(VectorDBBase): def __init__(self): settings_dict = { - "allow_reset": True, - "anonymized_telemetry": False, + 'allow_reset': True, + 'anonymized_telemetry': False, } if CHROMA_CLIENT_AUTH_PROVIDER is not None: - settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER + settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER if CHROMA_CLIENT_AUTH_CREDENTIALS is not None: - settings_dict["chroma_client_auth_credentials"] = ( - CHROMA_CLIENT_AUTH_CREDENTIALS - ) + settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS - if CHROMA_HTTP_HOST != "": + if CHROMA_HTTP_HOST != '': self.client = chromadb.HttpClient( host=CHROMA_HTTP_HOST, port=CHROMA_HTTP_PORT, @@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase): # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1 # https://docs.trychroma.com/docs/collections/configure cosine equation - distances: list = result["distances"][0] + distances: list = result['distances'][0] distances = [2 - dist for dist in distances] distances = [[dist / 2 for dist in distances]] return SearchResult( **{ - "ids": result["ids"], - "distances": distances, - "documents": result["documents"], - "metadatas": result["metadatas"], + 'ids': result['ids'], + 'distances': distances, + 'documents': result['documents'], + 'metadatas': result['metadatas'], } ) return None except Exception as e: return None - def query( - self, collection_name: str, filter: dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]: # Query the items from the collection based on the filter. try: collection = self.client.get_collection(name=collection_name) @@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase): return GetResult( **{ - "ids": [result["ids"]], - "documents": [result["documents"]], - "metadatas": [result["metadatas"]], + 'ids': [result['ids']], + 'documents': [result['documents']], + 'metadatas': [result['metadatas']], } ) return None @@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase): result = collection.get() return GetResult( **{ - "ids": [result["ids"]], - "documents": [result["documents"]], - "metadatas": [result["metadatas"]], + 'ids': [result['ids']], + 'documents': [result['documents']], + 'metadatas': [result['metadatas']], } ) return None def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. - collection = self.client.get_or_create_collection( - name=collection_name, metadata={"hnsw:space": "cosine"} - ) + collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'}) - ids = [item["id"] for item in items] - documents = [item["text"] for item in items] - embeddings = [item["vector"] for item in items] - metadatas = [process_metadata(item["metadata"]) for item in items] + ids = [item['id'] for item in items] + documents = [item['text'] for item in items] + embeddings = [item['vector'] for item in items] + metadatas = [process_metadata(item['metadata']) for item in items] for batch in create_batches( api=self.client, @@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase): def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. - collection = self.client.get_or_create_collection( - name=collection_name, metadata={"hnsw:space": "cosine"} - ) + collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'}) - ids = [item["id"] for item in items] - documents = [item["text"] for item in items] - embeddings = [item["vector"] for item in items] - metadatas = [process_metadata(item["metadata"]) for item in items] + ids = [item['id'] for item in items] + documents = [item['text'] for item in items] + embeddings = [item['vector'] for item in items] + metadatas = [process_metadata(item['metadata']) for item in items] - collection.upsert( - ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas - ) + collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas) def delete( self, @@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase): collection.delete(where=filter) except Exception as e: # If collection doesn't exist, that's fine - nothing to delete - log.debug( - f"Attempted to delete from non-existent collection {collection_name}. Ignoring." - ) + log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.') pass def reset(self): diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index dfb02ec029..201a5e1706 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase): # Status: works def _get_index_name(self, dimension: int) -> str: - return f"{self.index_prefix}_d{str(dimension)}" + return f'{self.index_prefix}_d{str(dimension)}' # Status: works def _scan_result_to_get_result(self, result) -> GetResult: @@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase): metadatas = [] for hit in result: - ids.append(hit["_id"]) - documents.append(hit["_source"].get("text")) - metadatas.append(hit["_source"].get("metadata")) + ids.append(hit['_id']) + documents.append(hit['_source'].get('text')) + metadatas.append(hit['_source'].get('metadata')) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) # Status: works def _result_to_get_result(self, result) -> GetResult: - if not result["hits"]["hits"]: + if not result['hits']['hits']: return None ids = [] documents = [] metadatas = [] - for hit in result["hits"]["hits"]: - ids.append(hit["_id"]) - documents.append(hit["_source"].get("text")) - metadatas.append(hit["_source"].get("metadata")) + for hit in result['hits']['hits']: + ids.append(hit['_id']) + documents.append(hit['_source'].get('text')) + metadatas.append(hit['_source'].get('metadata')) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) @@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase): documents = [] metadatas = [] - for hit in result["hits"]["hits"]: - ids.append(hit["_id"]) - distances.append(hit["_score"]) - documents.append(hit["_source"].get("text")) - metadatas.append(hit["_source"].get("metadata")) + for hit in result['hits']['hits']: + ids.append(hit['_id']) + distances.append(hit['_score']) + documents.append(hit['_source'].get('text')) + metadatas.append(hit['_source'].get('metadata')) return SearchResult( ids=[ids], @@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase): # Status: works def _create_index(self, dimension: int): body = { - "mappings": { - "dynamic_templates": [ + 'mappings': { + 'dynamic_templates': [ { - "strings": { - "match_mapping_type": "string", - "mapping": {"type": "keyword"}, + 'strings': { + 'match_mapping_type': 'string', + 'mapping': {'type': 'keyword'}, } } ], - "properties": { - "collection": {"type": "keyword"}, - "id": {"type": "keyword"}, - "vector": { - "type": "dense_vector", - "dims": dimension, # Adjust based on your vector dimensions - "index": True, - "similarity": "cosine", + 'properties': { + 'collection': {'type': 'keyword'}, + 'id': {'type': 'keyword'}, + 'vector': { + 'type': 'dense_vector', + 'dims': dimension, # Adjust based on your vector dimensions + 'index': True, + 'similarity': 'cosine', }, - "text": {"type": "text"}, - "metadata": {"type": "object"}, + 'text': {'type': 'text'}, + 'metadata': {'type': 'object'}, }, } } @@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase): # Status: works def has_collection(self, collection_name) -> bool: - query_body = {"query": {"bool": {"filter": []}}} - query_body["query"]["bool"]["filter"].append( - {"term": {"collection": collection_name}} - ) + query_body = {'query': {'bool': {'filter': []}}} + query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}}) try: - result = self.client.count(index=f"{self.index_prefix}*", body=query_body) + result = self.client.count(index=f'{self.index_prefix}*', body=query_body) - return result.body["count"] > 0 + return result.body['count'] > 0 except Exception as e: return None def delete_collection(self, collection_name: str): - query = {"query": {"term": {"collection": collection_name}}} - self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) + query = {'query': {'term': {'collection': collection_name}}} + self.client.delete_by_query(index=f'{self.index_prefix}*', body=query) # Status: works def search( @@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase): limit: int = 10, ) -> Optional[SearchResult]: query = { - "size": limit, - "_source": ["text", "metadata"], - "query": { - "script_score": { - "query": { - "bool": {"filter": [{"term": {"collection": collection_name}}]} - }, - "script": { - "source": "cosineSimilarity(params.vector, 'vector') + 1.0", - "params": { - "vector": vectors[0] - }, # Assuming single query vector + 'size': limit, + '_source': ['text', 'metadata'], + 'query': { + 'script_score': { + 'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}, + 'script': { + 'source': "cosineSimilarity(params.vector, 'vector') + 1.0", + 'params': {'vector': vectors[0]}, # Assuming single query vector }, } }, } - result = self.client.search( - index=self._get_index_name(len(vectors[0])), body=query - ) + result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query) return self._result_to_search_result(result) # Status: only tested halfwat - def query( - self, collection_name: str, filter: dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]: if not self.has_collection(collection_name): return None query_body = { - "query": {"bool": {"filter": []}}, - "_source": ["text", "metadata"], + 'query': {'bool': {'filter': []}}, + '_source': ['text', 'metadata'], } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append({"term": {field: value}}) - query_body["query"]["bool"]["filter"].append( - {"term": {"collection": collection_name}} - ) + query_body['query']['bool']['filter'].append({'term': {field: value}}) + query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}}) size = limit if limit else 10 try: result = self.client.search( - index=f"{self.index_prefix}*", + index=f'{self.index_prefix}*', body=query_body, size=size, ) @@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase): # Status: works def _has_index(self, dimension: int): - return self.client.indices.exists( - index=self._get_index_name(dimension=dimension) - ) + return self.client.indices.exists(index=self._get_index_name(dimension=dimension)) def get_or_create_index(self, dimension: int): if not self._has_index(dimension=dimension): @@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase): def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. query = { - "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}, - "_source": ["text", "metadata"], + 'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}, + '_source': ['text', 'metadata'], } - results = list(scan(self.client, index=f"{self.index_prefix}*", query=query)) + results = list(scan(self.client, index=f'{self.index_prefix}*', query=query)) return self._scan_result_to_get_result(results) # Status: works def insert(self, collection_name: str, items: list[VectorItem]): - if not self._has_index(dimension=len(items[0]["vector"])): - self._create_index(dimension=len(items[0]["vector"])) + if not self._has_index(dimension=len(items[0]['vector'])): + self._create_index(dimension=len(items[0]['vector'])) for batch in self._create_batches(items): actions = [ { - "_index": self._get_index_name(dimension=len(items[0]["vector"])), - "_id": item["id"], - "_source": { - "collection": collection_name, - "vector": item["vector"], - "text": item["text"], - "metadata": process_metadata(item["metadata"]), + '_index': self._get_index_name(dimension=len(items[0]['vector'])), + '_id': item['id'], + '_source': { + 'collection': collection_name, + 'vector': item['vector'], + 'text': item['text'], + 'metadata': process_metadata(item['metadata']), }, } for item in batch @@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase): # Upsert documents using the update API with doc_as_upsert=True. def upsert(self, collection_name: str, items: list[VectorItem]): - if not self._has_index(dimension=len(items[0]["vector"])): - self._create_index(dimension=len(items[0]["vector"])) + if not self._has_index(dimension=len(items[0]['vector'])): + self._create_index(dimension=len(items[0]['vector'])) for batch in self._create_batches(items): actions = [ { - "_op_type": "update", - "_index": self._get_index_name(dimension=len(item["vector"])), - "_id": item["id"], - "doc": { - "collection": collection_name, - "vector": item["vector"], - "text": item["text"], - "metadata": process_metadata(item["metadata"]), + '_op_type': 'update', + '_index': self._get_index_name(dimension=len(item['vector'])), + '_id': item['id'], + 'doc': { + 'collection': collection_name, + 'vector': item['vector'], + 'text': item['text'], + 'metadata': process_metadata(item['metadata']), }, - "doc_as_upsert": True, + 'doc_as_upsert': True, } for item in batch ] @@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase): ids: Optional[list[str]] = None, filter: Optional[dict] = None, ): - - query = { - "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}} - } + query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}} # logic based on chromaDB if ids: - query["query"]["bool"]["filter"].append({"terms": {"_id": ids}}) + query['query']['bool']['filter'].append({'terms': {'_id': ids}}) elif filter: for field, value in filter.items(): - query["query"]["bool"]["filter"].append( - {"term": {f"metadata.{field}": value}} - ) + query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}}) - self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) + self.client.delete_by_query(index=f'{self.index_prefix}*', body=query) def reset(self): - indices = self.client.indices.get(index=f"{self.index_prefix}*") + indices = self.client.indices.get(index=f'{self.index_prefix}*') for index in indices: self.client.indices.delete(index=index) diff --git a/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py b/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py index dfcfb3da59..1cb3563382 100644 --- a/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py +++ b/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py @@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes: byte sequence. We use array('f') to avoid a numpy dependency and byteswap on big-endian platforms for portability. """ - a = array.array("f", [float(x) for x in vec]) # float32 - if sys.byteorder != "little": + a = array.array('f', [float(x) for x in vec]) # float32 + if sys.byteorder != 'little': a.byteswap() return a.tobytes() @@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]: return v if isinstance(v, (bytes, bytearray)): try: - v = v.decode("utf-8") + v = v.decode('utf-8') except Exception: return {} if isinstance(v, str): @@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase): """ self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip() self.vector_length = int(vector_length) - self.distance_strategy = (distance_strategy or "cosine").strip().lower() + self.distance_strategy = (distance_strategy or 'cosine').strip().lower() self.index_m = int(index_m) - if self.distance_strategy not in {"cosine", "euclidean"}: + if self.distance_strategy not in {'cosine', 'euclidean'}: raise ValueError("distance_strategy must be 'cosine' or 'euclidean'") - if not self.db_url.lower().startswith("mariadb+mariadbconnector://"): + if not self.db_url.lower().startswith('mariadb+mariadbconnector://'): raise ValueError( - "MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) " - "to ensure qmark paramstyle and correct VECTOR binding." + 'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) ' + 'to ensure qmark paramstyle and correct VECTOR binding.' ) if isinstance(MARIADB_VECTOR_POOL_SIZE, int): @@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase): poolclass=QueuePool, ) else: - self.engine = create_engine( - self.db_url, pool_pre_ping=True, poolclass=NullPool - ) + self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool) else: self.engine = create_engine(self.db_url, pool_pre_ping=True) self._init_schema() @@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase): conn.commit() except Exception as e: conn.rollback() - log.exception(f"Error during database initialization: {e}") + log.exception(f'Error during database initialization: {e}') raise def _check_vector_length(self) -> None: @@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase): """ with self._connect() as conn: with conn.cursor() as cur: - cur.execute("SHOW CREATE TABLE document_chunk") + cur.execute('SHOW CREATE TABLE document_chunk') row = cur.fetchone() if not row or len(row) < 2: return ddl = row[1] - m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE) + m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE) if not m: return existing = int(m.group(1)) if existing != int(self.vector_length): raise Exception( - f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. " - "Cannot change vector size after initialization without migrating the data." + f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. ' + 'Cannot change vector size after initialization without migrating the data.' ) def adjust_vector_length(self, vector: List[float]) -> List[float]: @@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase): """ Return the MariaDB Vector distance function name for the configured strategy. """ - return ( - "vec_distance_cosine" - if self.distance_strategy == "cosine" - else "vec_distance_euclidean" - ) + return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean' def _score_from_dist(self, dist: float) -> float: """ @@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase): - cosine: score ~= 1 - cosine_distance, clamped to [0, 1] - euclidean: score = 1 / (1 + dist) """ - if self.distance_strategy == "cosine": + if self.distance_strategy == 'cosine': score = 1.0 - dist if score < 0.0: score = 0.0 @@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase): - {"$or": [ ... ]} """ if not expr or not isinstance(expr, dict): - return "", [] + return '', [] - if "$and" in expr: + if '$and' in expr: parts: List[str] = [] params: List[Any] = [] - for e in expr.get("$and") or []: + for e in expr.get('$and') or []: s, p = self._build_filter_sql_qmark(e) if s: parts.append(s) params.extend(p) - return ("(" + " AND ".join(parts) + ")") if parts else "", params + return ('(' + ' AND '.join(parts) + ')') if parts else '', params - if "$or" in expr: + if '$or' in expr: parts: List[str] = [] params: List[Any] = [] - for e in expr.get("$or") or []: + for e in expr.get('$or') or []: s, p = self._build_filter_sql_qmark(e) if s: parts.append(s) params.extend(p) - return ("(" + " OR ".join(parts) + ")") if parts else "", params + return ('(' + ' OR '.join(parts) + ')') if parts else '', params clauses: List[str] = [] params: List[Any] = [] for key, value in expr.items(): - if key.startswith("$"): + if key.startswith('$'): continue json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))" - if isinstance(value, dict) and "$in" in value: - vals = [str(v) for v in (value.get("$in") or [])] + if isinstance(value, dict) and '$in' in value: + vals = [str(v) for v in (value.get('$in') or [])] if not vals: - clauses.append("0=1") + clauses.append('0=1') continue ors = [] for v in vals: - ors.append(f"{json_expr} = ?") + ors.append(f'{json_expr} = ?') params.append(v) - clauses.append("(" + " OR ".join(ors) + ")") + clauses.append('(' + ' OR '.join(ors) + ')') else: - clauses.append(f"{json_expr} = ?") + clauses.append(f'{json_expr} = ?') params.append(str(value)) - return ("(" + " AND ".join(clauses) + ")") if clauses else "", params + return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params def insert(self, collection_name: str, items: List[VectorItem]) -> None: """ @@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase): """ params: List[Tuple[Any, ...]] = [] for item in items: - v = self.adjust_vector_length(item["vector"]) + v = self.adjust_vector_length(item['vector']) emb = _embedding_to_f32_bytes(v) - meta = process_metadata(item.get("metadata") or {}) + meta = process_metadata(item.get('metadata') or {}) params.append( ( - item["id"], + item['id'], emb, collection_name, - item.get("text"), + item.get('text'), json.dumps(meta), ) ) @@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase): conn.commit() except Exception as e: conn.rollback() - log.exception(f"Error during insert: {e}") + log.exception(f'Error during insert: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase): """ params: List[Tuple[Any, ...]] = [] for item in items: - v = self.adjust_vector_length(item["vector"]) + v = self.adjust_vector_length(item['vector']) emb = _embedding_to_f32_bytes(v) - meta = process_metadata(item.get("metadata") or {}) + meta = process_metadata(item.get('metadata') or {}) params.append( ( - item["id"], + item['id'], emb, collection_name, - item.get("text"), + item.get('text'), json.dumps(meta), ) ) @@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase): conn.commit() except Exception as e: conn.rollback() - log.exception(f"Error during upsert: {e}") + log.exception(f'Error during upsert: {e}') raise def search( @@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase): with self._connect() as conn: with conn.cursor() as cur: fsql, fparams = self._build_filter_sql_qmark(filter or {}) - where = "collection_name = ?" + where = 'collection_name = ?' base_params: List[Any] = [collection_name] if fsql: - where = where + " AND " + fsql + where = where + ' AND ' + fsql base_params.extend(fparams) sql = f""" @@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase): metadatas=metadatas, ) except Exception as e: - log.exception(f"[MARIADB_VECTOR] search() failed: {e}") + log.exception(f'[MARIADB_VECTOR] search() failed: {e}') return None - def query( - self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]: """ Retrieve documents by metadata filter (non-vector query). """ with self._connect() as conn: with conn.cursor() as cur: fsql, fparams = self._build_filter_sql_qmark(filter or {}) - where = "collection_name = ?" + where = 'collection_name = ?' params: List[Any] = [collection_name] if fsql: - where = where + " AND " + fsql + where = where + ' AND ' + fsql params.extend(fparams) - sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}" + sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}' if limit is not None: - sql += " LIMIT ?" + sql += ' LIMIT ?' params.append(int(limit)) cur.execute(sql, params) rows = cur.fetchall() @@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase): metadatas = [[_safe_json(r[2]) for r in rows]] return GetResult(ids=ids, documents=documents, metadatas=metadatas) - def get( - self, collection_name: str, limit: Optional[int] = None - ) -> Optional[GetResult]: + def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]: """ Retrieve documents in a collection without filtering (optionally limited). """ with self._connect() as conn: with conn.cursor() as cur: - sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?" + sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?' params: List[Any] = [collection_name] if limit is not None: - sql += " LIMIT ?" + sql += ' LIMIT ?' params.append(int(limit)) cur.execute(sql, params) rows = cur.fetchall() @@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase): with self._connect() as conn: with conn.cursor() as cur: try: - where = ["collection_name = ?"] + where = ['collection_name = ?'] params: List[Any] = [collection_name] if ids: - ph = ", ".join(["?"] * len(ids)) - where.append(f"id IN ({ph})") + ph = ', '.join(['?'] * len(ids)) + where.append(f'id IN ({ph})') params.extend(ids) if filter: @@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase): where.append(fsql) params.extend(fparams) - sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where) + sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where) cur.execute(sql, params) conn.commit() except Exception as e: conn.rollback() - log.exception(f"Error during delete: {e}") + log.exception(f'Error during delete: {e}') raise def reset(self) -> None: @@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase): with self._connect() as conn: with conn.cursor() as cur: try: - cur.execute("TRUNCATE TABLE document_chunk") + cur.execute('TRUNCATE TABLE document_chunk') conn.commit() except Exception as e: conn.rollback() - log.exception(f"Error during reset: {e}") + log.exception(f'Error during reset: {e}') raise def has_collection(self, collection_name: str) -> bool: @@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase): with self._connect() as conn: with conn.cursor() as cur: cur.execute( - "SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1", + 'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1', (collection_name,), ) return cur.fetchone() is not None @@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase): try: self.engine.dispose() except Exception as e: - log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}") + log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}') diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 4dcf76c64d..2f3d8f3890 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -35,7 +35,7 @@ log = logging.getLogger(__name__) class MilvusClient(VectorDBBase): def __init__(self): - self.collection_prefix = "open_webui" + self.collection_prefix = 'open_webui' if MILVUS_TOKEN is None: self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB) else: @@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase): _documents = [] _metadatas = [] for item in match: - _ids.append(item.get("id")) - _documents.append(item.get("data", {}).get("text")) - _metadatas.append(item.get("metadata")) + _ids.append(item.get('id')) + _documents.append(item.get('data', {}).get('text')) + _metadatas.append(item.get('metadata')) ids.append(_ids) documents.append(_documents) metadatas.append(_metadatas) return GetResult( **{ - "ids": ids, - "documents": documents, - "metadatas": metadatas, + 'ids': ids, + 'documents': documents, + 'metadatas': metadatas, } ) @@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase): _documents = [] _metadatas = [] for item in match: - _ids.append(item.get("id")) + _ids.append(item.get('id')) # normalize milvus score from [-1, 1] to [0, 1] range # https://milvus.io/docs/de/metric.md - _dist = (item.get("distance") + 1.0) / 2.0 + _dist = (item.get('distance') + 1.0) / 2.0 _distances.append(_dist) - _documents.append(item.get("entity", {}).get("data", {}).get("text")) - _metadatas.append(item.get("entity", {}).get("metadata")) + _documents.append(item.get('entity', {}).get('data', {}).get('text')) + _metadatas.append(item.get('entity', {}).get('metadata')) ids.append(_ids) distances.append(_distances) documents.append(_documents) metadatas.append(_metadatas) return SearchResult( **{ - "ids": ids, - "distances": distances, - "documents": documents, - "metadatas": metadatas, + 'ids': ids, + 'distances': distances, + 'documents': documents, + 'metadatas': metadatas, } ) @@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase): enable_dynamic_field=True, ) schema.add_field( - field_name="id", + field_name='id', datatype=DataType.VARCHAR, is_primary=True, max_length=65535, ) schema.add_field( - field_name="vector", + field_name='vector', datatype=DataType.FLOAT_VECTOR, dim=dimension, - description="vector", - ) - schema.add_field(field_name="data", datatype=DataType.JSON, description="data") - schema.add_field( - field_name="metadata", datatype=DataType.JSON, description="metadata" + description='vector', ) + schema.add_field(field_name='data', datatype=DataType.JSON, description='data') + schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata') index_params = self.client.prepare_index_params() @@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase): index_type = MILVUS_INDEX_TYPE.upper() metric_type = MILVUS_METRIC_TYPE.upper() - log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}") + log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}') index_creation_params = {} - if index_type == "HNSW": + if index_type == 'HNSW': index_creation_params = { - "M": MILVUS_HNSW_M, - "efConstruction": MILVUS_HNSW_EFCONSTRUCTION, + 'M': MILVUS_HNSW_M, + 'efConstruction': MILVUS_HNSW_EFCONSTRUCTION, } - log.info(f"HNSW params: {index_creation_params}") - elif index_type == "IVF_FLAT": - index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST} - log.info(f"IVF_FLAT params: {index_creation_params}") - elif index_type == "DISKANN": + log.info(f'HNSW params: {index_creation_params}') + elif index_type == 'IVF_FLAT': + index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST} + log.info(f'IVF_FLAT params: {index_creation_params}') + elif index_type == 'DISKANN': index_creation_params = { - "max_degree": MILVUS_DISKANN_MAX_DEGREE, - "search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE, + 'max_degree': MILVUS_DISKANN_MAX_DEGREE, + 'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE, } - log.info(f"DISKANN params: {index_creation_params}") - elif index_type in ["FLAT", "AUTOINDEX"]: - log.info(f"Using {index_type} index with no specific build-time params.") + log.info(f'DISKANN params: {index_creation_params}') + elif index_type in ['FLAT', 'AUTOINDEX']: + log.info(f'Using {index_type} index with no specific build-time params.') else: log.warning( f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. " - f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. " - f"Milvus will use its default for the collection if this type is not directly supported for index creation." + f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. ' + f'Milvus will use its default for the collection if this type is not directly supported for index creation.' ) # For unsupported types, pass the type directly to Milvus; it might handle it or use a default. # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var. index_params.add_index( - field_name="vector", + field_name='vector', index_type=index_type, metric_type=metric_type, params=index_creation_params, ) self.client.create_collection( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', schema=schema, index_params=index_params, ) @@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase): def has_collection(self, collection_name: str) -> bool: # Check if the collection exists based on the collection name. - collection_name = collection_name.replace("-", "_") - return self.client.has_collection( - collection_name=f"{self.collection_prefix}_{collection_name}" - ) + collection_name = collection_name.replace('-', '_') + return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}') def delete_collection(self, collection_name: str): # Delete the collection based on the collection name. - collection_name = collection_name.replace("-", "_") - return self.client.drop_collection( - collection_name=f"{self.collection_prefix}_{collection_name}" - ) + collection_name = collection_name.replace('-', '_') + return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}') def search( self, @@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase): limit: int = 10, ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. - collection_name = collection_name.replace("-", "_") + collection_name = collection_name.replace('-', '_') # For some index types like IVF_FLAT, search params like nprobe can be set. # Example: search_params = {"nprobe": 10} if using IVF_FLAT # For simplicity, not adding configurable search_params here, but could be extended. result = self.client.search( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', data=vectors, limit=limit, - output_fields=["data", "metadata"], + output_fields=['data', 'metadata'], # search_params=search_params # Potentially add later if needed ) return self._result_to_search_result(result) @@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase): def query(self, collection_name: str, filter: dict, limit: int = -1): connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB) - collection_name = collection_name.replace("-", "_") + collection_name = collection_name.replace('-', '_') if not self.has_collection(collection_name): - log.warning( - f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}" - ) + log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}') return None filter_expressions = [] @@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase): else: filter_expressions.append(f'metadata["{key}"] == {value}') - filter_string = " && ".join(filter_expressions) + filter_string = ' && '.join(filter_expressions) - collection = Collection(f"{self.collection_prefix}_{collection_name}") + collection = Collection(f'{self.collection_prefix}_{collection_name}') collection.load() try: @@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase): iterator = collection.query_iterator( expr=filter_string, output_fields=[ - "id", - "data", - "metadata", + 'id', + 'data', + 'metadata', ], limit=limit if limit > 0 else -1, ) @@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase): break all_results.extend(batch) - log.debug(f"Total results from query: {len(all_results)}") + log.debug(f'Total results from query: {len(all_results)}') return self._result_to_get_result([all_results] if all_results else [[]]) except Exception as e: @@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase): def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. This can be very resource-intensive for large collections. - collection_name = collection_name.replace("-", "_") + collection_name = collection_name.replace('-', '_') log.warning( f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections." ) @@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase): def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. - collection_name = collection_name.replace("-", "_") - if not self.client.has_collection( - collection_name=f"{self.collection_prefix}_{collection_name}" - ): - log.info( - f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now." - ) + collection_name = collection_name.replace('-', '_') + if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'): + log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.') if not items: log.error( - f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension." + f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.' ) - raise ValueError( - "Cannot create Milvus collection without items to determine vector dimension." - ) - self._create_collection( - collection_name=collection_name, dimension=len(items[0]["vector"]) - ) + raise ValueError('Cannot create Milvus collection without items to determine vector dimension.') + self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector'])) - log.info( - f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}." - ) + log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.') return self.client.insert( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', data=[ { - "id": item["id"], - "vector": item["vector"], - "data": {"text": item["text"]}, - "metadata": process_metadata(item["metadata"]), + 'id': item['id'], + 'vector': item['vector'], + 'data': {'text': item['text']}, + 'metadata': process_metadata(item['metadata']), } for item in items ], @@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase): def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. - collection_name = collection_name.replace("-", "_") - if not self.client.has_collection( - collection_name=f"{self.collection_prefix}_{collection_name}" - ): - log.info( - f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now." - ) + collection_name = collection_name.replace('-', '_') + if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'): + log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.') if not items: log.error( - f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension." + f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.' ) raise ValueError( - "Cannot create Milvus collection for upsert without items to determine vector dimension." + 'Cannot create Milvus collection for upsert without items to determine vector dimension.' ) - self._create_collection( - collection_name=collection_name, dimension=len(items[0]["vector"]) - ) + self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector'])) - log.info( - f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}." - ) + log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.') return self.client.upsert( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', data=[ { - "id": item["id"], - "vector": item["vector"], - "data": {"text": item["text"]}, - "metadata": process_metadata(item["metadata"]), + 'id': item['id'], + 'vector': item['vector'], + 'data': {'text': item['text']}, + 'metadata': process_metadata(item['metadata']), } for item in items ], @@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase): filter: Optional[dict] = None, ): # Delete the items from the collection based on the ids or filter. - collection_name = collection_name.replace("-", "_") + collection_name = collection_name.replace('-', '_') if not self.has_collection(collection_name): - log.warning( - f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}" - ) + log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}') return None if ids: - log.info( - f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}" - ) + log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}') return self.client.delete( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', ids=ids, ) elif filter: - filter_string = " && ".join( - [ - f'metadata["{key}"] == {json.dumps(value)}' - for key, value in filter.items() - ] - ) + filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()]) log.info( - f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}" + f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}' ) return self.client.delete( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', filter=filter_string, ) else: log.warning( - f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken." + f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.' ) return None def reset(self): # Resets the database. This will delete all collections and item entries that match the prefix. - log.warning( - f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'." - ) + log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.") collection_names = self.client.list_collections() deleted_collections = [] for collection_name_full in collection_names: @@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase): try: self.client.drop_collection(collection_name=collection_name_full) deleted_collections.append(collection_name_full) - log.info(f"Deleted collection: {collection_name_full}") + log.info(f'Deleted collection: {collection_name_full}') except Exception as e: - log.error(f"Error deleting collection {collection_name_full}: {e}") - log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}") + log.error(f'Error deleting collection {collection_name_full}: {e}') + log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}') diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py index 0ecbac15d2..93b4a8cbc4 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -33,26 +33,26 @@ from pymilvus import ( log = logging.getLogger(__name__) -RESOURCE_ID_FIELD = "resource_id" +RESOURCE_ID_FIELD = 'resource_id' class MilvusClient(VectorDBBase): def __init__(self): # Milvus collection names can only contain numbers, letters, and underscores. - self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_") + self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_') connections.connect( - alias="default", + alias='default', uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB, ) # Main collection types for multi-tenancy - self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" - self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" - self.FILE_COLLECTION = f"{self.collection_prefix}_files" - self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search" - self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based" + self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories' + self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge' + self.FILE_COLLECTION = f'{self.collection_prefix}_files' + self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search' + self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based' self.shared_collections = [ self.MEMORY_COLLECTION, self.KNOWLEDGE_COLLECTION, @@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase): """ resource_id = collection_name - if collection_name.startswith("user-memory-"): + if collection_name.startswith('user-memory-'): return self.MEMORY_COLLECTION, resource_id - elif collection_name.startswith("file-"): + elif collection_name.startswith('file-'): return self.FILE_COLLECTION, resource_id - elif collection_name.startswith("web-search-"): + elif collection_name.startswith('web-search-'): return self.WEB_SEARCH_COLLECTION, resource_id - elif len(collection_name) == 63 and all( - c in "0123456789abcdef" for c in collection_name - ): + elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name): return self.HASH_BASED_COLLECTION, resource_id else: return self.KNOWLEDGE_COLLECTION, resource_id @@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase): def _create_shared_collection(self, mt_collection_name: str, dimension: int): fields = [ FieldSchema( - name="id", + name='id', dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=36, ), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension), + FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name='metadata', dtype=DataType.JSON), FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255), ] - schema = CollectionSchema(fields, "Shared collection for multi-tenancy") + schema = CollectionSchema(fields, 'Shared collection for multi-tenancy') collection = Collection(mt_collection_name, schema) index_params = { - "metric_type": MILVUS_METRIC_TYPE, - "index_type": MILVUS_INDEX_TYPE, - "params": {}, + 'metric_type': MILVUS_METRIC_TYPE, + 'index_type': MILVUS_INDEX_TYPE, + 'params': {}, } - if MILVUS_INDEX_TYPE == "HNSW": - index_params["params"] = { - "M": MILVUS_HNSW_M, - "efConstruction": MILVUS_HNSW_EFCONSTRUCTION, + if MILVUS_INDEX_TYPE == 'HNSW': + index_params['params'] = { + 'M': MILVUS_HNSW_M, + 'efConstruction': MILVUS_HNSW_EFCONSTRUCTION, } - elif MILVUS_INDEX_TYPE == "IVF_FLAT": - index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST} + elif MILVUS_INDEX_TYPE == 'IVF_FLAT': + index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST} - collection.create_index("vector", index_params) + collection.create_index('vector', index_params) collection.create_index(RESOURCE_ID_FIELD) - log.info(f"Created shared collection: {mt_collection_name}") + log.info(f'Created shared collection: {mt_collection_name}') return collection def _ensure_collection(self, mt_collection_name: str, dimension: int): @@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase): self._create_shared_collection(mt_collection_name, dimension) def has_collection(self, collection_name: str) -> bool: - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) if not utility.has_collection(mt_collection): return False @@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase): def upsert(self, collection_name: str, items: List[VectorItem]): if not items: return - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) - dimension = len(items[0]["vector"]) + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) + dimension = len(items[0]['vector']) self._ensure_collection(mt_collection, dimension) collection = Collection(mt_collection) entities = [ { - "id": item["id"], - "vector": item["vector"], - "text": item["text"], - "metadata": item["metadata"], + 'id': item['id'], + 'vector': item['vector'], + 'text': item['text'], + 'metadata': item['metadata'], RESOURCE_ID_FIELD: resource_id, } for item in items @@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase): if not vectors: return None - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) if not utility.has_collection(mt_collection): return None collection = Collection(mt_collection) collection.load() - search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}} + search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}} results = collection.search( data=vectors, - anns_field="vector", + anns_field='vector', param=search_params, limit=limit, expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", - output_fields=["id", "text", "metadata"], + output_fields=['id', 'text', 'metadata'], ) ids, documents, metadatas, distances = [], [], [], [] for hits in results: batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], [] for hit in hits: - batch_ids.append(hit.entity.get("id")) - batch_docs.append(hit.entity.get("text")) - batch_metadatas.append(hit.entity.get("metadata")) + batch_ids.append(hit.entity.get('id')) + batch_docs.append(hit.entity.get('text')) + batch_metadatas.append(hit.entity.get('metadata')) batch_dists.append(hit.distance) ids.append(batch_ids) documents.append(batch_docs) metadatas.append(batch_metadatas) distances.append(batch_dists) - return SearchResult( - ids=ids, documents=documents, metadatas=metadatas, distances=distances - ) + return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances) def delete( self, @@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase): ids: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, ): - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) if not utility.has_collection(mt_collection): return @@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase): expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] if ids: # Milvus expects a string list for 'in' operator - id_list_str = ", ".join([f"'{id_val}'" for id_val in ids]) - expr.append(f"id in [{id_list_str}]") + id_list_str = ', '.join([f"'{id_val}'" for id_val in ids]) + expr.append(f'id in [{id_list_str}]') if filter: for key, value in filter.items(): expr.append(f"metadata['{key}'] == '{value}'") - collection.delete(" and ".join(expr)) + collection.delete(' and '.join(expr)) def reset(self): for collection_name in self.shared_collections: @@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase): utility.drop_collection(collection_name) def delete_collection(self, collection_name: str): - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) if not utility.has_collection(mt_collection): return collection = Collection(mt_collection) collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'") - def query( - self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None - ) -> Optional[GetResult]: - mt_collection, resource_id = self._get_collection_and_resource_id( - collection_name - ) + def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]: + mt_collection, resource_id = self._get_collection_and_resource_id(collection_name) if not utility.has_collection(mt_collection): return None @@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase): expr.append(f"metadata['{key}'] == {value}") iterator = collection.query_iterator( - expr=" and ".join(expr), - output_fields=["id", "text", "metadata"], + expr=' and '.join(expr), + output_fields=['id', 'text', 'metadata'], limit=limit if limit else -1, ) @@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase): break all_results.extend(batch) - ids = [res["id"] for res in all_results] - documents = [res["text"] for res in all_results] - metadatas = [res["metadata"] for res in all_results] + ids = [res['id'] for res in all_results] + documents = [res['text'] for res in all_results] + metadatas = [res['metadata'] for res in all_results] return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py index 679847a1d4..ac97cf01fa 100644 --- a/backend/open_webui/retrieval/vector/dbs/opengauss.py +++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py @@ -36,17 +36,15 @@ from sqlalchemy.dialects import registry class OpenGaussDialect(PGDialect_psycopg2): - name = "opengauss" + name = 'opengauss' def _get_server_version_info(self, connection): try: - version = connection.exec_driver_sql("SELECT version()").scalar() + version = connection.exec_driver_sql('SELECT version()').scalar() if not version: return (9, 0, 0) - match = re.search( - r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE - ) + match = re.search(r'openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?', version, re.IGNORECASE) if match: return (int(match.group(1)), int(match.group(2)), int(match.group(3))) @@ -56,7 +54,7 @@ class OpenGaussDialect(PGDialect_psycopg2): # Register dialect -registry.register("opengauss", __name__, "OpenGaussDialect") +registry.register('opengauss', __name__, 'OpenGaussDialect') from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( @@ -80,11 +78,11 @@ VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH Base = declarative_base() log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["RAG"]) +log.setLevel(SRC_LOG_LEVELS['RAG']) class DocumentChunk(Base): - __tablename__ = "document_chunk" + __tablename__ = 'document_chunk' id = Column(Text, primary_key=True) vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) @@ -100,26 +98,24 @@ class OpenGaussClient(VectorDBBase): self.session = ScopedSession else: - engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()} + engine_kwargs = {'pool_pre_ping': True, 'dialect': OpenGaussDialect()} if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0: engine_kwargs.update( { - "pool_size": OPENGAUSS_POOL_SIZE, - "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW, - "pool_timeout": OPENGAUSS_POOL_TIMEOUT, - "pool_recycle": OPENGAUSS_POOL_RECYCLE, - "poolclass": QueuePool, + 'pool_size': OPENGAUSS_POOL_SIZE, + 'max_overflow': OPENGAUSS_POOL_MAX_OVERFLOW, + 'pool_timeout': OPENGAUSS_POOL_TIMEOUT, + 'pool_recycle': OPENGAUSS_POOL_RECYCLE, + 'poolclass': QueuePool, } ) else: - engine_kwargs["poolclass"] = NullPool + engine_kwargs['poolclass'] = NullPool engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs) - SessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine, expire_on_commit=False - ) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) self.session = scoped_session(SessionLocal) try: @@ -128,47 +124,42 @@ class OpenGaussClient(VectorDBBase): self.session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " - "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" + 'CREATE INDEX IF NOT EXISTS idx_document_chunk_vector ' + 'ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);' ) ) self.session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " - "ON document_chunk (collection_name);" + 'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);' ) ) self.session.commit() - log.info("OpenGauss vector database initialization completed.") + log.info('OpenGauss vector database initialization completed.') except Exception as e: self.session.rollback() - log.exception(f"OpenGauss Initialization failed.: {e}") + log.exception(f'OpenGauss Initialization failed.: {e}') raise def check_vector_length(self) -> None: metadata = MetaData() try: - document_chunk_table = Table( - "document_chunk", metadata, autoload_with=self.session.bind - ) + document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind) except NoSuchTableError: return - if "vector" in document_chunk_table.columns: - vector_column = document_chunk_table.columns["vector"] + if 'vector' in document_chunk_table.columns: + vector_column = document_chunk_table.columns['vector'] vector_type = vector_column.type if isinstance(vector_type, Vector): db_vector_length = vector_type.dim if db_vector_length != VECTOR_LENGTH: raise Exception( - f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database." + f'Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database.' ) else: raise Exception("The 'vector' column type is not Vector.") else: - raise Exception( - "The 'vector' column does not exist in the 'document_chunk' table." - ) + raise Exception("The 'vector' column does not exist in the 'document_chunk' table.") def adjust_vector_length(self, vector: List[float]) -> List[float]: current_length = len(vector) @@ -182,55 +173,47 @@ class OpenGaussClient(VectorDBBase): try: new_items = [] for item in items: - vector = self.adjust_vector_length(item["vector"]) + vector = self.adjust_vector_length(item['vector']) new_chunk = DocumentChunk( - id=item["id"], + id=item['id'], vector=vector, collection_name=collection_name, - text=item["text"], - vmetadata=process_metadata(item["metadata"]), + text=item['text'], + vmetadata=process_metadata(item['metadata']), ) new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() - log.info( - f"Inserting {len(new_items)} items into collection '{collection_name}'." - ) + log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() - log.exception(f"Failed to insert data: {e}") + log.exception(f'Failed to insert data: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: try: for item in items: - vector = self.adjust_vector_length(item["vector"]) - existing = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.id == item["id"]) - .first() - ) + vector = self.adjust_vector_length(item['vector']) + existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first() if existing: existing.vector = vector - existing.text = item["text"] - existing.vmetadata = process_metadata(item["metadata"]) + existing.text = item['text'] + existing.vmetadata = process_metadata(item['metadata']) existing.collection_name = collection_name else: new_chunk = DocumentChunk( - id=item["id"], + id=item['id'], vector=vector, collection_name=collection_name, - text=item["text"], - vmetadata=process_metadata(item["metadata"]), + text=item['text'], + vmetadata=process_metadata(item['metadata']), ) self.session.add(new_chunk) self.session.commit() - log.info( - f"Inserting/updating {len(items)} items in collection '{collection_name}'." - ) + log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.") except Exception as e: self.session.rollback() - log.exception(f"Failed to insert or update data.: {e}") + log.exception(f'Failed to insert or update data.: {e}') raise def search( @@ -250,35 +233,29 @@ class OpenGaussClient(VectorDBBase): def vector_expr(vector): return cast(array(vector), Vector(VECTOR_LENGTH)) - qid_col = column("qid", Integer) - q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + qid_col = column('qid', Integer) + q_vector_col = column('q_vector', Vector(VECTOR_LENGTH)) query_vectors = ( values(qid_col, q_vector_col) - .data( - [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] - ) - .alias("query_vectors") + .data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]) + .alias('query_vectors') ) result_fields = [ DocumentChunk.id, DocumentChunk.text, DocumentChunk.vmetadata, - (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( - "distance" - ), + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'), ] subq = ( select(*result_fields) .where(DocumentChunk.collection_name == collection_name) - .order_by( - DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) - ) + .order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) ) if limit is not None: subq = subq.limit(limit) - subq = subq.lateral("result") + subq = subq.lateral('result') stmt = ( select( @@ -309,21 +286,15 @@ class OpenGaussClient(VectorDBBase): metadatas[qid].append(row.vmetadata) self.session.rollback() - return SearchResult( - ids=ids, distances=distances, documents=documents, metadatas=metadatas - ) + return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() - log.exception(f"Vector search failed: {e}") + log.exception(f'Vector search failed: {e}') return None - def query( - self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) for key, value in filter.items(): query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) @@ -344,16 +315,12 @@ class OpenGaussClient(VectorDBBase): return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() - log.exception(f"Conditional query failed: {e}") + log.exception(f'Conditional query failed: {e}') return None - def get( - self, collection_name: str, limit: Optional[int] = None - ) -> Optional[GetResult]: + def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if limit is not None: query = query.limit(limit) @@ -370,7 +337,7 @@ class OpenGaussClient(VectorDBBase): return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() - log.exception(f"Failed to retrieve data: {e}") + log.exception(f'Failed to retrieve data: {e}') return None def delete( @@ -380,32 +347,28 @@ class OpenGaussClient(VectorDBBase): filter: Optional[Dict[str, Any]] = None, ) -> None: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if ids: query = query.filter(DocumentChunk.id.in_(ids)) if filter: for key, value in filter.items(): - query = query.filter( - DocumentChunk.vmetadata[key].astext == str(value) - ) + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'") except Exception as e: self.session.rollback() - log.exception(f"Failed to delete data: {e}") + log.exception(f'Failed to delete data: {e}') raise def reset(self) -> None: try: deleted = self.session.query(DocumentChunk).delete() self.session.commit() - log.info(f"Reset completed. Deleted {deleted} items") + log.info(f'Reset completed. Deleted {deleted} items') except Exception as e: self.session.rollback() - log.exception(f"Reset failed: {e}") + log.exception(f'Reset failed: {e}') raise def close(self) -> None: @@ -414,16 +377,14 @@ class OpenGaussClient(VectorDBBase): def has_collection(self, collection_name: str) -> bool: try: exists = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.collection_name == collection_name) - .first() + self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first() is not None ) self.session.rollback() return exists except Exception as e: self.session.rollback() - log.exception(f"Failed to check collection existence: {e}") + log.exception(f'Failed to check collection existence: {e}') return False def delete_collection(self, collection_name: str) -> None: diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 3ad82d7442..a08dca7865 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -24,7 +24,7 @@ from open_webui.config import ( class OpenSearchClient(VectorDBBase): def __init__(self): - self.index_prefix = "open_webui" + self.index_prefix = 'open_webui' self.client = OpenSearch( hosts=[OPENSEARCH_URI], use_ssl=OPENSEARCH_SSL, @@ -33,25 +33,25 @@ class OpenSearchClient(VectorDBBase): ) def _get_index_name(self, collection_name: str) -> str: - return f"{self.index_prefix}_{collection_name}" + return f'{self.index_prefix}_{collection_name}' def _result_to_get_result(self, result) -> GetResult: - if not result["hits"]["hits"]: + if not result['hits']['hits']: return None ids = [] documents = [] metadatas = [] - for hit in result["hits"]["hits"]: - ids.append(hit["_id"]) - documents.append(hit["_source"].get("text")) - metadatas.append(hit["_source"].get("metadata")) + for hit in result['hits']['hits']: + ids.append(hit['_id']) + documents.append(hit['_source'].get('text')) + metadatas.append(hit['_source'].get('metadata')) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) def _result_to_search_result(self, result) -> SearchResult: - if not result["hits"]["hits"]: + if not result['hits']['hits']: return None ids = [] @@ -59,11 +59,11 @@ class OpenSearchClient(VectorDBBase): documents = [] metadatas = [] - for hit in result["hits"]["hits"]: - ids.append(hit["_id"]) - distances.append(hit["_score"]) - documents.append(hit["_source"].get("text")) - metadatas.append(hit["_source"].get("metadata")) + for hit in result['hits']['hits']: + ids.append(hit['_id']) + distances.append(hit['_score']) + documents.append(hit['_source'].get('text')) + metadatas.append(hit['_source'].get('metadata')) return SearchResult( ids=[ids], @@ -74,33 +74,31 @@ class OpenSearchClient(VectorDBBase): def _create_index(self, collection_name: str, dimension: int): body = { - "settings": {"index": {"knn": True}}, - "mappings": { - "properties": { - "id": {"type": "keyword"}, - "vector": { - "type": "knn_vector", - "dimension": dimension, # Adjust based on your vector dimensions - "index": True, - "similarity": "faiss", - "method": { - "name": "hnsw", - "space_type": "innerproduct", # Use inner product to approximate cosine similarity - "engine": "faiss", - "parameters": { - "ef_construction": 128, - "m": 16, + 'settings': {'index': {'knn': True}}, + 'mappings': { + 'properties': { + 'id': {'type': 'keyword'}, + 'vector': { + 'type': 'knn_vector', + 'dimension': dimension, # Adjust based on your vector dimensions + 'index': True, + 'similarity': 'faiss', + 'method': { + 'name': 'hnsw', + 'space_type': 'innerproduct', # Use inner product to approximate cosine similarity + 'engine': 'faiss', + 'parameters': { + 'ef_construction': 128, + 'm': 16, }, }, }, - "text": {"type": "text"}, - "metadata": {"type": "object"}, + 'text': {'type': 'text'}, + 'metadata': {'type': 'object'}, } }, } - self.client.indices.create( - index=self._get_index_name(collection_name), body=body - ) + self.client.indices.create(index=self._get_index_name(collection_name), body=body) def _create_batches(self, items: list[VectorItem], batch_size=100): for i in range(0, len(items), batch_size): @@ -128,46 +126,40 @@ class OpenSearchClient(VectorDBBase): return None query = { - "size": limit, - "_source": ["text", "metadata"], - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0", - "params": { - "field": "vector", - "query_value": vectors[0], + 'size': limit, + '_source': ['text', 'metadata'], + 'query': { + 'script_score': { + 'query': {'match_all': {}}, + 'script': { + 'source': '(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0', + 'params': { + 'field': 'vector', + 'query_value': vectors[0], }, # Assuming single query vector }, } }, } - result = self.client.search( - index=self._get_index_name(collection_name), body=query - ) + result = self.client.search(index=self._get_index_name(collection_name), body=query) return self._result_to_search_result(result) except Exception as e: return None - def query( - self, collection_name: str, filter: dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]: if not self.has_collection(collection_name): return None query_body = { - "query": {"bool": {"filter": []}}, - "_source": ["text", "metadata"], + 'query': {'bool': {'filter': []}}, + '_source': ['text', 'metadata'], } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append( - {"term": {"metadata." + str(field) + ".keyword": value}} - ) + query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}}) size = limit if limit else 10000 @@ -188,28 +180,24 @@ class OpenSearchClient(VectorDBBase): self._create_index(collection_name, dimension) def get(self, collection_name: str) -> Optional[GetResult]: - query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]} + query = {'query': {'match_all': {}}, '_source': ['text', 'metadata']} - result = self.client.search( - index=self._get_index_name(collection_name), body=query - ) + result = self.client.search(index=self._get_index_name(collection_name), body=query) return self._result_to_get_result(result) def insert(self, collection_name: str, items: list[VectorItem]): - self._create_index_if_not_exists( - collection_name=collection_name, dimension=len(items[0]["vector"]) - ) + self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector'])) for batch in self._create_batches(items): actions = [ { - "_op_type": "index", - "_index": self._get_index_name(collection_name), - "_id": item["id"], - "_source": { - "vector": item["vector"], - "text": item["text"], - "metadata": process_metadata(item["metadata"]), + '_op_type': 'index', + '_index': self._get_index_name(collection_name), + '_id': item['id'], + '_source': { + 'vector': item['vector'], + 'text': item['text'], + 'metadata': process_metadata(item['metadata']), }, } for item in batch @@ -218,22 +206,20 @@ class OpenSearchClient(VectorDBBase): self.client.indices.refresh(index=self._get_index_name(collection_name)) def upsert(self, collection_name: str, items: list[VectorItem]): - self._create_index_if_not_exists( - collection_name=collection_name, dimension=len(items[0]["vector"]) - ) + self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector'])) for batch in self._create_batches(items): actions = [ { - "_op_type": "update", - "_index": self._get_index_name(collection_name), - "_id": item["id"], - "doc": { - "vector": item["vector"], - "text": item["text"], - "metadata": process_metadata(item["metadata"]), + '_op_type': 'update', + '_index': self._get_index_name(collection_name), + '_id': item['id'], + 'doc': { + 'vector': item['vector'], + 'text': item['text'], + 'metadata': process_metadata(item['metadata']), }, - "doc_as_upsert": True, + 'doc_as_upsert': True, } for item in batch ] @@ -249,27 +235,23 @@ class OpenSearchClient(VectorDBBase): if ids: actions = [ { - "_op_type": "delete", - "_index": self._get_index_name(collection_name), - "_id": id, + '_op_type': 'delete', + '_index': self._get_index_name(collection_name), + '_id': id, } for id in ids ] bulk(self.client, actions) elif filter: query_body = { - "query": {"bool": {"filter": []}}, + 'query': {'bool': {'filter': []}}, } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append( - {"term": {"metadata." + str(field) + ".keyword": value}} - ) - self.client.delete_by_query( - index=self._get_index_name(collection_name), body=query_body - ) + query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}}) + self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body) self.client.indices.refresh(index=self._get_index_name(collection_name)) def reset(self): - indices = self.client.indices.get(index=f"{self.index_prefix}_*") + indices = self.client.indices.get(index=f'{self.index_prefix}_*') for index in indices: self.client.indices.delete(index=index) diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py index 574de75303..9a5bd638d9 100644 --- a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -94,15 +94,15 @@ class Oracle23aiClient(VectorDBBase): self._create_dbcs_pool() dsn = ORACLE_DB_DSN - log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]") + log.info(f'Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]') with self.get_connection() as connection: - log.info(f"Connection version: {connection.version}") + log.info(f'Connection version: {connection.version}') self._initialize_database(connection) - log.info("Oracle Vector Search initialization complete.") + log.info('Oracle Vector Search initialization complete.') except Exception as e: - log.exception(f"Error during Oracle Vector Search initialization: {e}") + log.exception(f'Error during Oracle Vector Search initialization: {e}') raise def _create_adb_pool(self) -> None: @@ -122,7 +122,7 @@ class Oracle23aiClient(VectorDBBase): wallet_location=ORACLE_WALLET_DIR, wallet_password=ORACLE_WALLET_PASSWORD, ) - log.info("Created ADB connection pool with wallet authentication.") + log.info('Created ADB connection pool with wallet authentication.') def _create_dbcs_pool(self) -> None: """ @@ -138,7 +138,7 @@ class Oracle23aiClient(VectorDBBase): max=ORACLE_DB_POOL_MAX, increment=ORACLE_DB_POOL_INCREMENT, ) - log.info("Created DB connection pool with basic authentication.") + log.info('Created DB connection pool with basic authentication.') def get_connection(self): """ @@ -155,13 +155,11 @@ class Oracle23aiClient(VectorDBBase): return connection except oracledb.DatabaseError as e: (error_obj,) = e.args - log.exception( - f"Connection attempt {attempt + 1} failed: {error_obj.message}" - ) + log.exception(f'Connection attempt {attempt + 1} failed: {error_obj.message}') if attempt < max_retries - 1: wait_time = 2**attempt - log.info(f"Retrying in {wait_time} seconds...") + log.info(f'Retrying in {wait_time} seconds...') time.sleep(wait_time) else: raise @@ -177,30 +175,30 @@ class Oracle23aiClient(VectorDBBase): def _monitor(): while True: try: - log.info("[HealthCheck] Running periodic DB health check...") + log.info('[HealthCheck] Running periodic DB health check...') self.ensure_connection() - log.info("[HealthCheck] Connection is healthy.") + log.info('[HealthCheck] Connection is healthy.') except Exception as e: - log.exception(f"[HealthCheck] Connection health check failed: {e}") + log.exception(f'[HealthCheck] Connection health check failed: {e}') time.sleep(interval_seconds) thread = threading.Thread(target=_monitor, daemon=True) thread.start() - log.info(f"Started DB health monitor every {interval_seconds} seconds.") + log.info(f'Started DB health monitor every {interval_seconds} seconds.') def _reconnect_pool(self): """ Attempt to reinitialize the connection pool if it's been closed or broken. """ try: - log.info("Attempting to reinitialize the Oracle connection pool...") + log.info('Attempting to reinitialize the Oracle connection pool...') # Close existing pool if it exists if self.pool: try: self.pool.close() except Exception as close_error: - log.warning(f"Error closing existing pool: {close_error}") + log.warning(f'Error closing existing pool: {close_error}') # Re-create the appropriate connection pool based on DB type if ORACLE_DB_USE_WALLET: @@ -208,9 +206,9 @@ class Oracle23aiClient(VectorDBBase): else: # DBCS self._create_dbcs_pool() - log.info("Connection pool reinitialized.") + log.info('Connection pool reinitialized.') except Exception as e: - log.exception(f"Failed to reinitialize the connection pool: {e}") + log.exception(f'Failed to reinitialize the connection pool: {e}') raise def ensure_connection(self): @@ -220,11 +218,9 @@ class Oracle23aiClient(VectorDBBase): try: with self.get_connection() as connection: with connection.cursor() as cursor: - cursor.execute("SELECT 1 FROM dual") + cursor.execute('SELECT 1 FROM dual') except Exception as e: - log.exception( - f"Connection check failed: {e}, attempting to reconnect pool..." - ) + log.exception(f'Connection check failed: {e}, attempting to reconnect pool...') self._reconnect_pool() def _output_type_handler(self, cursor, metadata): @@ -239,9 +235,7 @@ class Oracle23aiClient(VectorDBBase): A variable with appropriate conversion for vector types """ if metadata.type_code is oracledb.DB_TYPE_VECTOR: - return cursor.var( - metadata.type_code, arraysize=cursor.arraysize, outconverter=list - ) + return cursor.var(metadata.type_code, arraysize=cursor.arraysize, outconverter=list) def _initialize_database(self, connection) -> None: """ @@ -257,7 +251,7 @@ class Oracle23aiClient(VectorDBBase): """ with connection.cursor() as cursor: try: - log.info("Creating Table document_chunk") + log.info('Creating Table document_chunk') cursor.execute( """ BEGIN @@ -279,7 +273,7 @@ class Oracle23aiClient(VectorDBBase): """ ) - log.info("Creating Index document_chunk_collection_name_idx") + log.info('Creating Index document_chunk_collection_name_idx') cursor.execute( """ BEGIN @@ -296,7 +290,7 @@ class Oracle23aiClient(VectorDBBase): """ ) - log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx") + log.info('Creating VECTOR INDEX document_chunk_vector_ivf_idx') cursor.execute( """ BEGIN @@ -318,11 +312,11 @@ class Oracle23aiClient(VectorDBBase): ) connection.commit() - log.info("Database initialization completed successfully.") + log.info('Database initialization completed successfully.') except Exception as e: connection.rollback() - log.exception(f"Error during database initialization: {e}") + log.exception(f'Error during database initialization: {e}') raise def check_vector_length(self) -> None: @@ -344,7 +338,7 @@ class Oracle23aiClient(VectorDBBase): Returns: bytes: The vector in Oracle BLOB format """ - return array.array("f", vector) + return array.array('f', vector) def adjust_vector_length(self, vector: List[float]) -> List[float]: """ @@ -373,7 +367,7 @@ class Oracle23aiClient(VectorDBBase): """ if isinstance(obj, Decimal): return float(obj) - raise TypeError(f"{obj} is not JSON serializable") + raise TypeError(f'{obj} is not JSON serializable') def _metadata_to_json(self, metadata: Dict) -> str: """ @@ -385,7 +379,7 @@ class Oracle23aiClient(VectorDBBase): Returns: str: JSON representation of metadata """ - return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}" + return json.dumps(metadata, default=self._decimal_handler) if metadata else '{}' def _json_to_metadata(self, json_str: str) -> Dict: """ @@ -424,8 +418,8 @@ class Oracle23aiClient(VectorDBBase): try: with connection.cursor() as cursor: for item in items: - vector_blob = self._vector_to_blob(item["vector"]) - metadata_json = self._metadata_to_json(item["metadata"]) + vector_blob = self._vector_to_blob(item['vector']) + metadata_json = self._metadata_to_json(item['metadata']) cursor.execute( """ @@ -434,22 +428,20 @@ class Oracle23aiClient(VectorDBBase): VALUES (:id, :collection_name, :text, :metadata, :vector) """, { - "id": item["id"], - "collection_name": collection_name, - "text": item["text"], - "metadata": metadata_json, - "vector": vector_blob, + 'id': item['id'], + 'collection_name': collection_name, + 'text': item['text'], + 'metadata': metadata_json, + 'vector': vector_blob, }, ) connection.commit() - log.info( - f"Successfully inserted {len(items)} items into collection '{collection_name}'." - ) + log.info(f"Successfully inserted {len(items)} items into collection '{collection_name}'.") except Exception as e: connection.rollback() - log.exception(f"Error during insert: {e}") + log.exception(f'Error during insert: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -480,8 +472,8 @@ class Oracle23aiClient(VectorDBBase): try: with connection.cursor() as cursor: for item in items: - vector_blob = self._vector_to_blob(item["vector"]) - metadata_json = self._metadata_to_json(item["metadata"]) + vector_blob = self._vector_to_blob(item['vector']) + metadata_json = self._metadata_to_json(item['metadata']) cursor.execute( """ @@ -499,27 +491,25 @@ class Oracle23aiClient(VectorDBBase): VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector) """, { - "merge_id": item["id"], - "upd_collection_name": collection_name, - "upd_text": item["text"], - "upd_metadata": metadata_json, - "upd_vector": vector_blob, - "ins_id": item["id"], - "ins_collection_name": collection_name, - "ins_text": item["text"], - "ins_metadata": metadata_json, - "ins_vector": vector_blob, + 'merge_id': item['id'], + 'upd_collection_name': collection_name, + 'upd_text': item['text'], + 'upd_metadata': metadata_json, + 'upd_vector': vector_blob, + 'ins_id': item['id'], + 'ins_collection_name': collection_name, + 'ins_text': item['text'], + 'ins_metadata': metadata_json, + 'ins_vector': vector_blob, }, ) connection.commit() - log.info( - f"Successfully upserted {len(items)} items into collection '{collection_name}'." - ) + log.info(f"Successfully upserted {len(items)} items into collection '{collection_name}'.") except Exception as e: connection.rollback() - log.exception(f"Error during upsert: {e}") + log.exception(f'Error during upsert: {e}') raise def search( @@ -551,13 +541,11 @@ class Oracle23aiClient(VectorDBBase): ... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])): ... log.info(f"Match {i+1}: id={id}, distance={dist}") """ - log.info( - f"Searching items from collection '{collection_name}' with limit {limit}." - ) + log.info(f"Searching items from collection '{collection_name}' with limit {limit}.") try: if not vectors: - log.warning("No vectors provided for search.") + log.warning('No vectors provided for search.') return None num_queries = len(vectors) @@ -583,9 +571,9 @@ class Oracle23aiClient(VectorDBBase): FETCH APPROX FIRST :limit ROWS ONLY """, { - "query_vector": vector_blob, - "collection_name": collection_name, - "limit": limit, + 'query_vector': vector_blob, + 'collection_name': collection_name, + 'limit': limit, }, ) @@ -593,35 +581,21 @@ class Oracle23aiClient(VectorDBBase): for row in results: ids[qid].append(row[0]) - documents[qid].append( - row[1].read() - if isinstance(row[1], oracledb.LOB) - else str(row[1]) - ) + documents[qid].append(row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])) # 🔧 FIXED: Parse JSON metadata properly - metadata_str = ( - row[2].read() - if isinstance(row[2], oracledb.LOB) - else row[2] - ) + metadata_str = row[2].read() if isinstance(row[2], oracledb.LOB) else row[2] metadatas[qid].append(self._json_to_metadata(metadata_str)) distances[qid].append(float(row[3])) - log.info( - f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results." - ) + log.info(f'Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results.') - return SearchResult( - ids=ids, distances=distances, documents=documents, metadatas=metadatas - ) + return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas) except Exception as e: - log.exception(f"Error during search: {e}") + log.exception(f'Error during search: {e}') return None - def query( - self, collection_name: str, filter: Dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]: """ Query items based on metadata filters. @@ -653,15 +627,15 @@ class Oracle23aiClient(VectorDBBase): WHERE collection_name = :collection_name """ - params = {"collection_name": collection_name} + params = {'collection_name': collection_name} for i, (key, value) in enumerate(filter.items()): - param_name = f"value_{i}" + param_name = f'value_{i}' query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}" params[param_name] = str(value) - query += " FETCH FIRST :limit ROWS ONLY" - params["limit"] = limit + query += ' FETCH FIRST :limit ROWS ONLY' + params['limit'] = limit with self.get_connection() as connection: with connection.cursor() as cursor: @@ -669,32 +643,25 @@ class Oracle23aiClient(VectorDBBase): results = cursor.fetchall() if not results: - log.info("No results found for query.") + log.info('No results found for query.') return None ids = [[row[0] for row in results]] - documents = [ - [ - row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) - for row in results - ] - ] + documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]] # 🔧 FIXED: Parse JSON metadata properly metadatas = [ [ - self._json_to_metadata( - row[2].read() if isinstance(row[2], oracledb.LOB) else row[2] - ) + self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]) for row in results ] ] - log.info(f"Query completed. Found {len(results)} results.") + log.info(f'Query completed. Found {len(results)} results.') return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: - log.exception(f"Error during query: {e}") + log.exception(f'Error during query: {e}') return None def get(self, collection_name: str) -> Optional[GetResult]: @@ -729,28 +696,21 @@ class Oracle23aiClient(VectorDBBase): WHERE collection_name = :collection_name FETCH FIRST :limit ROWS ONLY """, - {"collection_name": collection_name, "limit": limit}, + {'collection_name': collection_name, 'limit': limit}, ) results = cursor.fetchall() if not results: - log.info("No results found.") + log.info('No results found.') return None ids = [[row[0] for row in results]] - documents = [ - [ - row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) - for row in results - ] - ] + documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]] # 🔧 FIXED: Parse JSON metadata properly metadatas = [ [ - self._json_to_metadata( - row[2].read() if isinstance(row[2], oracledb.LOB) else row[2] - ) + self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]) for row in results ] ] @@ -758,7 +718,7 @@ class Oracle23aiClient(VectorDBBase): return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: - log.exception(f"Error during get: {e}") + log.exception(f'Error during get: {e}') return None def delete( @@ -790,21 +750,19 @@ class Oracle23aiClient(VectorDBBase): log.info(f"Deleting items from collection '{collection_name}'.") try: - query = ( - "DELETE FROM document_chunk WHERE collection_name = :collection_name" - ) - params = {"collection_name": collection_name} + query = 'DELETE FROM document_chunk WHERE collection_name = :collection_name' + params = {'collection_name': collection_name} if ids: # 🔧 FIXED: Use proper parameterized query to prevent SQL injection - placeholders = ",".join([f":id_{i}" for i in range(len(ids))]) - query += f" AND id IN ({placeholders})" + placeholders = ','.join([f':id_{i}' for i in range(len(ids))]) + query += f' AND id IN ({placeholders})' for i, id_val in enumerate(ids): - params[f"id_{i}"] = id_val + params[f'id_{i}'] = id_val if filter: for i, (key, value) in enumerate(filter.items()): - param_name = f"value_{i}" + param_name = f'value_{i}' query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}" params[param_name] = str(value) @@ -817,7 +775,7 @@ class Oracle23aiClient(VectorDBBase): log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: - log.exception(f"Error during delete: {e}") + log.exception(f'Error during delete: {e}') raise def reset(self) -> None: @@ -833,21 +791,19 @@ class Oracle23aiClient(VectorDBBase): >>> client = Oracle23aiClient() >>> client.reset() # Warning: Removes all data! """ - log.info("Resetting database - deleting all items.") + log.info('Resetting database - deleting all items.') try: with self.get_connection() as connection: with connection.cursor() as cursor: - cursor.execute("DELETE FROM document_chunk") + cursor.execute('DELETE FROM document_chunk') deleted = cursor.rowcount connection.commit() - log.info( - f"Reset complete. Deleted {deleted} items from 'document_chunk' table." - ) + log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.") except Exception as e: - log.exception(f"Error during reset: {e}") + log.exception(f'Error during reset: {e}') raise def close(self) -> None: @@ -862,11 +818,11 @@ class Oracle23aiClient(VectorDBBase): >>> client.close() """ try: - if hasattr(self, "pool") and self.pool: + if hasattr(self, 'pool') and self.pool: self.pool.close() - log.info("Oracle Vector Search connection pool closed.") + log.info('Oracle Vector Search connection pool closed.') except Exception as e: - log.exception(f"Error closing connection pool: {e}") + log.exception(f'Error closing connection pool: {e}') def has_collection(self, collection_name: str) -> bool: """ @@ -895,7 +851,7 @@ class Oracle23aiClient(VectorDBBase): WHERE collection_name = :collection_name FETCH FIRST 1 ROWS ONLY """, - {"collection_name": collection_name}, + {'collection_name': collection_name}, ) count = cursor.fetchone()[0] @@ -903,7 +859,7 @@ class Oracle23aiClient(VectorDBBase): return count > 0 except Exception as e: - log.exception(f"Error checking collection existence: {e}") + log.exception(f'Error checking collection existence: {e}') return False def delete_collection(self, collection_name: str) -> None: @@ -929,15 +885,13 @@ class Oracle23aiClient(VectorDBBase): DELETE FROM document_chunk WHERE collection_name = :collection_name """, - {"collection_name": collection_name}, + {'collection_name': collection_name}, ) deleted = cursor.rowcount connection.commit() - log.info( - f"Collection '{collection_name}' deleted. Removed {deleted} items." - ) + log.info(f"Collection '{collection_name}' deleted. Removed {deleted} items.") except Exception as e: log.exception(f"Error deleting collection '{collection_name}': {e}") diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 481f9d92fc..4775ff21f4 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -55,7 +55,7 @@ VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH USE_HALFVEC = PGVECTOR_USE_HALFVEC VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector -VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops" +VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops' Base = declarative_base() log = logging.getLogger(__name__) @@ -65,12 +65,12 @@ def pgcrypto_encrypt(val, key): return func.pgp_sym_encrypt(val, literal(key)) -def pgcrypto_decrypt(col, key, outtype="text"): +def pgcrypto_decrypt(col, key, outtype='text'): return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype) class DocumentChunk(Base): - __tablename__ = "document_chunk" + __tablename__ = 'document_chunk' id = Column(Text, primary_key=True) vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True) @@ -86,7 +86,6 @@ class DocumentChunk(Base): class PgvectorClient(VectorDBBase): def __init__(self) -> None: - # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: from open_webui.internal.db import ScopedSession @@ -105,46 +104,44 @@ class PgvectorClient(VectorDBBase): poolclass=QueuePool, ) else: - engine = create_engine( - PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool - ) + engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool) else: engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True) - SessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine, expire_on_commit=False - ) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) self.session = scoped_session(SessionLocal) try: # Ensure the pgvector extension is available # Use a conditional check to avoid permission issues on Azure PostgreSQL if PGVECTOR_CREATE_EXTENSION: - self.session.execute(text(""" + self.session.execute( + text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN CREATE EXTENSION IF NOT EXISTS vector; END IF; END $$; - """)) + """) + ) if PGVECTOR_PGCRYPTO: # Ensure the pgcrypto extension is available for encryption # Use a conditional check to avoid permission issues on Azure PostgreSQL - self.session.execute(text(""" + self.session.execute( + text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN CREATE EXTENSION IF NOT EXISTS pgcrypto; END IF; END $$; - """)) + """) + ) if not PGVECTOR_PGCRYPTO_KEY: - raise ValueError( - "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled." - ) + raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.') # Check vector length consistency self.check_vector_length() @@ -160,15 +157,14 @@ class PgvectorClient(VectorDBBase): self.session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " - "ON document_chunk (collection_name);" + 'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);' ) ) self.session.commit() - log.info("Initialization complete.") + log.info('Initialization complete.') except Exception as e: self.session.rollback() - log.exception(f"Error during initialization: {e}") + log.exception(f'Error during initialization: {e}') raise @staticmethod @@ -176,7 +172,7 @@ class PgvectorClient(VectorDBBase): if not index_def: return None try: - after_using = index_def.lower().split("using ", 1)[1] + after_using = index_def.lower().split('using ', 1)[1] return after_using.split()[0] except (IndexError, AttributeError): return None @@ -189,23 +185,23 @@ class PgvectorClient(VectorDBBase): index_method, ) elif USE_HALFVEC: - index_method = "hnsw" + index_method = 'hnsw' log.info( - "VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.", + 'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.', VECTOR_LENGTH, ) else: - index_method = "ivfflat" + index_method = 'ivfflat' - if index_method == "hnsw": - index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})" + if index_method == 'hnsw': + index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})' else: - index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})" + index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})' return index_method, index_options def _ensure_vector_index(self, index_method: str, index_options: str) -> None: - index_name = "idx_document_chunk_vector" + index_name = 'idx_document_chunk_vector' existing_index_def = self.session.execute( text(""" SELECT indexdef @@ -214,7 +210,7 @@ class PgvectorClient(VectorDBBase): AND tablename = 'document_chunk' AND indexname = :index_name """), - {"index_name": index_name}, + {'index_name': index_name}, ).scalar() existing_method = self._extract_index_method(existing_index_def) @@ -222,23 +218,23 @@ class PgvectorClient(VectorDBBase): raise RuntimeError( f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now " f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. " - "Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) " - "and recreate it with the new method before restarting Open WebUI." + 'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) ' + 'and recreate it with the new method before restarting Open WebUI.' ) if not existing_index_def: index_sql = ( - f"CREATE INDEX IF NOT EXISTS {index_name} " - f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})" + f'CREATE INDEX IF NOT EXISTS {index_name} ' + f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})' ) if index_options: - index_sql = f"{index_sql} {index_options}" + index_sql = f'{index_sql} {index_options}' self.session.execute(text(index_sql)) log.info( "Ensured vector index '%s' using %s%s.", index_name, index_method, - f" {index_options}" if index_options else "", + f' {index_options}' if index_options else '', ) def check_vector_length(self) -> None: @@ -249,16 +245,14 @@ class PgvectorClient(VectorDBBase): metadata = MetaData() try: # Attempt to reflect the 'document_chunk' table - document_chunk_table = Table( - "document_chunk", metadata, autoload_with=self.session.bind - ) + document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind) except NoSuchTableError: # Table does not exist; no action needed return # Proceed to check the vector column - if "vector" in document_chunk_table.columns: - vector_column = document_chunk_table.columns["vector"] + if 'vector' in document_chunk_table.columns: + vector_column = document_chunk_table.columns['vector'] vector_type = vector_column.type expected_type = HALFVEC if USE_HALFVEC else Vector @@ -268,16 +262,14 @@ class PgvectorClient(VectorDBBase): f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}." ) - db_vector_length = getattr(vector_type, "dim", None) + db_vector_length = getattr(vector_type, 'dim', None) if db_vector_length is not None and db_vector_length != VECTOR_LENGTH: raise Exception( - f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. " - "Cannot change vector size after initialization without migrating the data." + f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. ' + 'Cannot change vector size after initialization without migrating the data.' ) else: - raise Exception( - "The 'vector' column does not exist in the 'document_chunk' table." - ) + raise Exception("The 'vector' column does not exist in the 'document_chunk' table.") def adjust_vector_length(self, vector: List[float]) -> List[float]: # Adjust vector to have length VECTOR_LENGTH @@ -294,10 +286,10 @@ class PgvectorClient(VectorDBBase): try: if PGVECTOR_PGCRYPTO: for item in items: - vector = self.adjust_vector_length(item["vector"]) + vector = self.adjust_vector_length(item['vector']) # Use raw SQL for BYTEA/pgcrypto # Ensure metadata is converted to its JSON text representation - json_metadata = json.dumps(item["metadata"]) + json_metadata = json.dumps(item['metadata']) self.session.execute( text(""" INSERT INTO document_chunk @@ -310,12 +302,12 @@ class PgvectorClient(VectorDBBase): ON CONFLICT (id) DO NOTHING """), { - "id": item["id"], - "vector": vector, - "collection_name": collection_name, - "text": item["text"], - "metadata_text": json_metadata, - "key": PGVECTOR_PGCRYPTO_KEY, + 'id': item['id'], + 'vector': vector, + 'collection_name': collection_name, + 'text': item['text'], + 'metadata_text': json_metadata, + 'key': PGVECTOR_PGCRYPTO_KEY, }, ) self.session.commit() @@ -324,31 +316,29 @@ class PgvectorClient(VectorDBBase): else: new_items = [] for item in items: - vector = self.adjust_vector_length(item["vector"]) + vector = self.adjust_vector_length(item['vector']) new_chunk = DocumentChunk( - id=item["id"], + id=item['id'], vector=vector, collection_name=collection_name, - text=item["text"], - vmetadata=process_metadata(item["metadata"]), + text=item['text'], + vmetadata=process_metadata(item['metadata']), ) new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() - log.info( - f"Inserted {len(new_items)} items into collection '{collection_name}'." - ) + log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() - log.exception(f"Error during insert: {e}") + log.exception(f'Error during insert: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: try: if PGVECTOR_PGCRYPTO: for item in items: - vector = self.adjust_vector_length(item["vector"]) - json_metadata = json.dumps(item["metadata"]) + vector = self.adjust_vector_length(item['vector']) + json_metadata = json.dumps(item['metadata']) self.session.execute( text(""" INSERT INTO document_chunk @@ -365,47 +355,39 @@ class PgvectorClient(VectorDBBase): vmetadata = EXCLUDED.vmetadata """), { - "id": item["id"], - "vector": vector, - "collection_name": collection_name, - "text": item["text"], - "metadata_text": json_metadata, - "key": PGVECTOR_PGCRYPTO_KEY, + 'id': item['id'], + 'vector': vector, + 'collection_name': collection_name, + 'text': item['text'], + 'metadata_text': json_metadata, + 'key': PGVECTOR_PGCRYPTO_KEY, }, ) self.session.commit() log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'") else: for item in items: - vector = self.adjust_vector_length(item["vector"]) - existing = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.id == item["id"]) - .first() - ) + vector = self.adjust_vector_length(item['vector']) + existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first() if existing: existing.vector = vector - existing.text = item["text"] - existing.vmetadata = process_metadata(item["metadata"]) - existing.collection_name = ( - collection_name # Update collection_name if necessary - ) + existing.text = item['text'] + existing.vmetadata = process_metadata(item['metadata']) + existing.collection_name = collection_name # Update collection_name if necessary else: new_chunk = DocumentChunk( - id=item["id"], + id=item['id'], vector=vector, collection_name=collection_name, - text=item["text"], - vmetadata=process_metadata(item["metadata"]), + text=item['text'], + vmetadata=process_metadata(item['metadata']), ) self.session.add(new_chunk) self.session.commit() - log.info( - f"Upserted {len(items)} items into collection '{collection_name}'." - ) + log.info(f"Upserted {len(items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() - log.exception(f"Error during upsert: {e}") + log.exception(f'Error during upsert: {e}') raise def search( @@ -427,38 +409,26 @@ class PgvectorClient(VectorDBBase): return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) # Create the values for query vectors - qid_col = column("qid", Integer) - q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) + qid_col = column('qid', Integer) + q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) query_vectors = ( values(qid_col, q_vector_col) - .data( - [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] - ) - .alias("query_vectors") + .data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]) + .alias('query_vectors') ) result_fields = [ DocumentChunk.id, ] if PGVECTOR_PGCRYPTO: + result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text')) result_fields.append( - pgcrypto_decrypt( - DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text - ).label("text") - ) - result_fields.append( - pgcrypto_decrypt( - DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB - ).label("vmetadata") + pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata') ) else: result_fields.append(DocumentChunk.text) result_fields.append(DocumentChunk.vmetadata) - result_fields.append( - (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( - "distance" - ) - ) + result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance')) # Build the lateral subquery for each query vector where_clauses = [DocumentChunk.collection_name == collection_name] @@ -466,9 +436,9 @@ class PgvectorClient(VectorDBBase): # Apply metadata filter if provided if filter: for key, value in filter.items(): - if isinstance(value, dict) and "$in" in value: + if isinstance(value, dict) and '$in' in value: # Handle $in operator: {"field": {"$in": [values]}} - in_values = value["$in"] + in_values = value['$in'] if PGVECTOR_PGCRYPTO: where_clauses.append( pgcrypto_decrypt( @@ -478,11 +448,7 @@ class PgvectorClient(VectorDBBase): )[key].astext.in_([str(v) for v in in_values]) ) else: - where_clauses.append( - DocumentChunk.vmetadata[key].astext.in_( - [str(v) for v in in_values] - ) - ) + where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values])) else: # Handle simple equality: {"field": "value"} if PGVECTOR_PGCRYPTO: @@ -495,20 +461,16 @@ class PgvectorClient(VectorDBBase): == str(value) ) else: - where_clauses.append( - DocumentChunk.vmetadata[key].astext == str(value) - ) + where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value)) subq = ( select(*result_fields) .where(*where_clauses) - .order_by( - (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) - ) + .order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))) ) if limit is not None: subq = subq.limit(limit) - subq = subq.lateral("result") + subq = subq.lateral('result') # Build the main query by joining query_vectors and the lateral subquery stmt = ( @@ -550,17 +512,13 @@ class PgvectorClient(VectorDBBase): metadatas[qid].append(row.vmetadata) self.session.rollback() # read-only transaction - return SearchResult( - ids=ids, distances=distances, documents=documents, metadatas=metadatas - ) + return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() - log.exception(f"Error during search: {e}") + log.exception(f'Error during search: {e}') return None - def query( - self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]: try: if PGVECTOR_PGCRYPTO: # Build where clause for vmetadata filter @@ -568,32 +526,22 @@ class PgvectorClient(VectorDBBase): for key, value in filter.items(): # decrypt then check key: JSON filter after decryption where_clauses.append( - pgcrypto_decrypt( - DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB - )[key].astext + pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext == str(value) ) stmt = select( DocumentChunk.id, - pgcrypto_decrypt( - DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text - ).label("text"), - pgcrypto_decrypt( - DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB - ).label("vmetadata"), + pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'), + pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'), ).where(*where_clauses) if limit is not None: stmt = stmt.limit(limit) results = self.session.execute(stmt).all() else: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) for key, value in filter.items(): - query = query.filter( - DocumentChunk.vmetadata[key].astext == str(value) - ) + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) if limit is not None: query = query.limit(limit) @@ -615,22 +563,16 @@ class PgvectorClient(VectorDBBase): ) except Exception as e: self.session.rollback() - log.exception(f"Error during query: {e}") + log.exception(f'Error during query: {e}') return None - def get( - self, collection_name: str, limit: Optional[int] = None - ) -> Optional[GetResult]: + def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]: try: if PGVECTOR_PGCRYPTO: stmt = select( DocumentChunk.id, - pgcrypto_decrypt( - DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text - ).label("text"), - pgcrypto_decrypt( - DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB - ).label("vmetadata"), + pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'), + pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'), ).where(DocumentChunk.collection_name == collection_name) if limit is not None: stmt = stmt.limit(limit) @@ -639,10 +581,7 @@ class PgvectorClient(VectorDBBase): documents = [[row.text for row in results]] metadatas = [[row.vmetadata for row in results]] else: - - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if limit is not None: query = query.limit(limit) @@ -659,7 +598,7 @@ class PgvectorClient(VectorDBBase): return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() - log.exception(f"Error during get: {e}") + log.exception(f'Error during get: {e}') return None def delete( @@ -676,43 +615,35 @@ class PgvectorClient(VectorDBBase): if filter: for key, value in filter.items(): wheres.append( - pgcrypto_decrypt( - DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB - )[key].astext + pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext == str(value) ) stmt = DocumentChunk.__table__.delete().where(*wheres) result = self.session.execute(stmt) deleted = result.rowcount else: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if ids: query = query.filter(DocumentChunk.id.in_(ids)) if filter: for key, value in filter.items(): - query = query.filter( - DocumentChunk.vmetadata[key].astext == str(value) - ) + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: self.session.rollback() - log.exception(f"Error during delete: {e}") + log.exception(f'Error during delete: {e}') raise def reset(self) -> None: try: deleted = self.session.query(DocumentChunk).delete() self.session.commit() - log.info( - f"Reset complete. Deleted {deleted} items from 'document_chunk' table." - ) + log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.") except Exception as e: self.session.rollback() - log.exception(f"Error during reset: {e}") + log.exception(f'Error during reset: {e}') raise def close(self) -> None: @@ -721,16 +652,14 @@ class PgvectorClient(VectorDBBase): def has_collection(self, collection_name: str) -> bool: try: exists = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.collection_name == collection_name) - .first() + self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first() is not None ) self.session.rollback() # read-only transaction return exists except Exception as e: self.session.rollback() - log.exception(f"Error checking collection existence: {e}") + log.exception(f'Error checking collection existence: {e}') return False def delete_collection(self, collection_name: str) -> None: diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 27bc50b70e..6469ac9172 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -45,7 +45,7 @@ log = logging.getLogger(__name__) class PineconeClient(VectorDBBase): def __init__(self): - self.collection_prefix = "open-webui" + self.collection_prefix = 'open-webui' # Validate required configuration self._validate_config() @@ -67,7 +67,7 @@ class PineconeClient(VectorDBBase): timeout=30, # Reasonable timeout for operations ) self.using_grpc = True - log.info("Using Pinecone gRPC client for optimal performance") + log.info('Using Pinecone gRPC client for optimal performance') else: # Fallback to HTTP client with enhanced connection pooling self.client = Pinecone( @@ -76,7 +76,7 @@ class PineconeClient(VectorDBBase): timeout=30, # Reasonable timeout for operations ) self.using_grpc = False - log.info("Using Pinecone HTTP client (gRPC not available)") + log.info('Using Pinecone HTTP client (gRPC not available)') # Persistent executor for batch operations self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) @@ -88,20 +88,18 @@ class PineconeClient(VectorDBBase): """Validate that all required configuration variables are set.""" missing_vars = [] if not PINECONE_API_KEY: - missing_vars.append("PINECONE_API_KEY") + missing_vars.append('PINECONE_API_KEY') if not PINECONE_ENVIRONMENT: - missing_vars.append("PINECONE_ENVIRONMENT") + missing_vars.append('PINECONE_ENVIRONMENT') if not PINECONE_INDEX_NAME: - missing_vars.append("PINECONE_INDEX_NAME") + missing_vars.append('PINECONE_INDEX_NAME') if not PINECONE_DIMENSION: - missing_vars.append("PINECONE_DIMENSION") + missing_vars.append('PINECONE_DIMENSION') if not PINECONE_CLOUD: - missing_vars.append("PINECONE_CLOUD") + missing_vars.append('PINECONE_CLOUD') if missing_vars: - raise ValueError( - f"Required configuration missing: {', '.join(missing_vars)}" - ) + raise ValueError(f'Required configuration missing: {", ".join(missing_vars)}') def _initialize_index(self) -> None: """Initialize the Pinecone index.""" @@ -126,8 +124,8 @@ class PineconeClient(VectorDBBase): ) except Exception as e: - log.error(f"Failed to initialize Pinecone index: {e}") - raise RuntimeError(f"Failed to initialize Pinecone index: {e}") + log.error(f'Failed to initialize Pinecone index: {e}') + raise RuntimeError(f'Failed to initialize Pinecone index: {e}') def _retry_pinecone_operation(self, operation_func, max_retries=3): """Retry Pinecone operations with exponential backoff for rate limits and network issues.""" @@ -140,18 +138,18 @@ class PineconeClient(VectorDBBase): is_retryable = any( keyword in error_str for keyword in [ - "rate limit", - "quota", - "timeout", - "network", - "connection", - "unavailable", - "internal error", - "429", - "500", - "502", - "503", - "504", + 'rate limit', + 'quota', + 'timeout', + 'network', + 'connection', + 'unavailable', + 'internal error', + '429', + '500', + '502', + '503', + '504', ] ) @@ -162,45 +160,42 @@ class PineconeClient(VectorDBBase): # Exponential backoff with jitter delay = (2**attempt) + random.uniform(0, 1) log.warning( - f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), " - f"retrying in {delay:.2f}s: {e}" + f'Pinecone operation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay:.2f}s: {e}' ) time.sleep(delay) - def _create_points( - self, items: List[VectorItem], collection_name_with_prefix: str - ) -> List[Dict[str, Any]]: + def _create_points(self, items: List[VectorItem], collection_name_with_prefix: str) -> List[Dict[str, Any]]: """Convert VectorItem objects to Pinecone point format.""" points = [] for item in items: # Start with any existing metadata or an empty dict - metadata = item.get("metadata", {}).copy() if item.get("metadata") else {} + metadata = item.get('metadata', {}).copy() if item.get('metadata') else {} # Add text to metadata if available - if "text" in item: - metadata["text"] = item["text"] + if 'text' in item: + metadata['text'] = item['text'] # Always add collection_name to metadata for filtering - metadata["collection_name"] = collection_name_with_prefix + metadata['collection_name'] = collection_name_with_prefix point = { - "id": item["id"], - "values": item["vector"], - "metadata": process_metadata(metadata), + 'id': item['id'], + 'values': item['vector'], + 'metadata': process_metadata(metadata), } points.append(point) return points def _get_collection_name_with_prefix(self, collection_name: str) -> str: """Get the collection name with prefix.""" - return f"{self.collection_prefix}_{collection_name}" + return f'{self.collection_prefix}_{collection_name}' def _normalize_distance(self, score: float) -> float: """Normalize distance score based on the metric used.""" - if self.metric.lower() == "cosine": + if self.metric.lower() == 'cosine': # Cosine similarity ranges from -1 to 1, normalize to 0 to 1 return (score + 1.0) / 2.0 - elif self.metric.lower() in ["euclidean", "dotproduct"]: + elif self.metric.lower() in ['euclidean', 'dotproduct']: # These are already suitable for ranking (smaller is better for Euclidean) return score else: @@ -214,68 +209,56 @@ class PineconeClient(VectorDBBase): metadatas = [] for match in matches: - metadata = getattr(match, "metadata", {}) or {} - ids.append(match.id if hasattr(match, "id") else match["id"]) - documents.append(metadata.get("text", "")) + metadata = getattr(match, 'metadata', {}) or {} + ids.append(match.id if hasattr(match, 'id') else match['id']) + documents.append(metadata.get('text', '')) metadatas.append(metadata) return GetResult( **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], + 'ids': [ids], + 'documents': [documents], + 'metadatas': [metadatas], } ) def has_collection(self, collection_name: str) -> bool: """Check if a collection exists by searching for at least one item.""" - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) try: # Search for at least 1 item with this collection name in metadata response = self.index.query( vector=[0.0] * self.dimension, # dummy vector top_k=1, - filter={"collection_name": collection_name_with_prefix}, + filter={'collection_name': collection_name_with_prefix}, include_metadata=False, ) - matches = getattr(response, "matches", []) or [] + matches = getattr(response, 'matches', []) or [] return len(matches) > 0 except Exception as e: - log.exception( - f"Error checking collection '{collection_name_with_prefix}': {e}" - ) + log.exception(f"Error checking collection '{collection_name_with_prefix}': {e}") return False def delete_collection(self, collection_name: str) -> None: """Delete a collection by removing all vectors with the collection name in metadata.""" - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) try: - self.index.delete(filter={"collection_name": collection_name_with_prefix}) - log.info( - f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)." - ) + self.index.delete(filter={'collection_name': collection_name_with_prefix}) + log.info(f"Collection '{collection_name_with_prefix}' deleted (all vectors removed).") except Exception as e: - log.warning( - f"Failed to delete collection '{collection_name_with_prefix}': {e}" - ) + log.warning(f"Failed to delete collection '{collection_name_with_prefix}': {e}") raise def insert(self, collection_name: str, items: List[VectorItem]) -> None: """Insert vectors into a collection.""" if not items: - log.warning("No items to insert") + log.warning('No items to insert') return start_time = time.time() - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) points = self._create_points(items, collection_name_with_prefix) # Parallelize batch inserts for performance @@ -288,26 +271,23 @@ class PineconeClient(VectorDBBase): try: future.result() except Exception as e: - log.error(f"Error inserting batch: {e}") + log.error(f'Error inserting batch: {e}') raise elapsed = time.time() - start_time - log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds") + log.debug(f'Insert of {len(points)} vectors took {elapsed:.2f} seconds') log.info( - f"Successfully inserted {len(points)} vectors in parallel batches " - f"into '{collection_name_with_prefix}'" + f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" ) def upsert(self, collection_name: str, items: List[VectorItem]) -> None: """Upsert (insert or update) vectors into a collection.""" if not items: - log.warning("No items to upsert") + log.warning('No items to upsert') return start_time = time.time() - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) points = self._create_points(items, collection_name_with_prefix) # Parallelize batch upserts for performance @@ -320,78 +300,53 @@ class PineconeClient(VectorDBBase): try: future.result() except Exception as e: - log.error(f"Error upserting batch: {e}") + log.error(f'Error upserting batch: {e}') raise elapsed = time.time() - start_time - log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds") + log.debug(f'Upsert of {len(points)} vectors took {elapsed:.2f} seconds') log.info( - f"Successfully upserted {len(points)} vectors in parallel batches " - f"into '{collection_name_with_prefix}'" + f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" ) async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None: """Async version of insert using asyncio and run_in_executor for improved performance.""" if not items: - log.warning("No items to insert") + log.warning('No items to insert') return - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) points = self._create_points(items, collection_name_with_prefix) # Create batches - batches = [ - points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE) - ] + batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)] loop = asyncio.get_event_loop() - tasks = [ - loop.run_in_executor( - None, functools.partial(self.index.upsert, vectors=batch) - ) - for batch in batches - ] + tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches] results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: if isinstance(result, Exception): - log.error(f"Error in async insert batch: {result}") + log.error(f'Error in async insert batch: {result}') raise result - log.info( - f"Successfully async inserted {len(points)} vectors in batches " - f"into '{collection_name_with_prefix}'" - ) + log.info(f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'") async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None: """Async version of upsert using asyncio and run_in_executor for improved performance.""" if not items: - log.warning("No items to upsert") + log.warning('No items to upsert') return - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) points = self._create_points(items, collection_name_with_prefix) # Create batches - batches = [ - points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE) - ] + batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)] loop = asyncio.get_event_loop() - tasks = [ - loop.run_in_executor( - None, functools.partial(self.index.upsert, vectors=batch) - ) - for batch in batches - ] + tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches] results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: if isinstance(result, Exception): - log.error(f"Error in async upsert batch: {result}") + log.error(f'Error in async upsert batch: {result}') raise result - log.info( - f"Successfully async upserted {len(points)} vectors in batches " - f"into '{collection_name_with_prefix}'" - ) + log.info(f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'") def search( self, @@ -402,12 +357,10 @@ class PineconeClient(VectorDBBase): ) -> Optional[SearchResult]: """Search for similar vectors in a collection.""" if not vectors or not vectors[0]: - log.warning("No vectors provided for search") + log.warning('No vectors provided for search') return None - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) if limit is None or limit <= 0: limit = NO_LIMIT @@ -421,10 +374,10 @@ class PineconeClient(VectorDBBase): vector=query_vector, top_k=limit, include_metadata=True, - filter={"collection_name": collection_name_with_prefix}, + filter={'collection_name': collection_name_with_prefix}, ) - matches = getattr(query_response, "matches", []) or [] + matches = getattr(query_response, 'matches', []) or [] if not matches: # Return empty result if no matches return SearchResult( @@ -438,12 +391,7 @@ class PineconeClient(VectorDBBase): get_result = self._result_to_get_result(matches) # Calculate normalized distances based on metric - distances = [ - [ - self._normalize_distance(getattr(match, "score", 0.0)) - for match in matches - ] - ] + distances = [[self._normalize_distance(getattr(match, 'score', 0.0)) for match in matches]] return SearchResult( ids=get_result.ids, @@ -455,13 +403,9 @@ class PineconeClient(VectorDBBase): log.error(f"Error searching in '{collection_name_with_prefix}': {e}") return None - def query( - self, collection_name: str, filter: Dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]: """Query vectors by metadata filter.""" - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) if limit is None or limit <= 0: limit = NO_LIMIT @@ -471,7 +415,7 @@ class PineconeClient(VectorDBBase): zero_vector = [0.0] * self.dimension # Combine user filter with collection_name - pinecone_filter = {"collection_name": collection_name_with_prefix} + pinecone_filter = {'collection_name': collection_name_with_prefix} if filter: pinecone_filter.update(filter) @@ -483,7 +427,7 @@ class PineconeClient(VectorDBBase): include_metadata=True, ) - matches = getattr(query_response, "matches", []) or [] + matches = getattr(query_response, 'matches', []) or [] return self._result_to_get_result(matches) except Exception as e: @@ -492,9 +436,7 @@ class PineconeClient(VectorDBBase): def get(self, collection_name: str) -> Optional[GetResult]: """Get all vectors in a collection.""" - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) try: # Use a zero vector for fetching all entries @@ -505,10 +447,10 @@ class PineconeClient(VectorDBBase): vector=zero_vector, top_k=NO_LIMIT, include_metadata=True, - filter={"collection_name": collection_name_with_prefix}, + filter={'collection_name': collection_name_with_prefix}, ) - matches = getattr(query_response, "matches", []) or [] + matches = getattr(query_response, 'matches', []) or [] return self._result_to_get_result(matches) except Exception as e: @@ -522,9 +464,7 @@ class PineconeClient(VectorDBBase): filter: Optional[Dict] = None, ) -> None: """Delete vectors by IDs or filter.""" - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) + collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name) try: if ids: @@ -534,28 +474,20 @@ class PineconeClient(VectorDBBase): # Note: When deleting by ID, we can't filter by collection_name # This is a limitation of Pinecone - be careful with ID uniqueness self.index.delete(ids=batch_ids) - log.debug( - f"Deleted batch of {len(batch_ids)} vectors by ID " - f"from '{collection_name_with_prefix}'" - ) - log.info( - f"Successfully deleted {len(ids)} vectors by ID " - f"from '{collection_name_with_prefix}'" - ) + log.debug(f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'") + log.info(f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'") elif filter: # Combine user filter with collection_name - pinecone_filter = {"collection_name": collection_name_with_prefix} + pinecone_filter = {'collection_name': collection_name_with_prefix} if filter: pinecone_filter.update(filter) # Delete by metadata filter self.index.delete(filter=pinecone_filter) - log.info( - f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'" - ) + log.info(f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'") else: - log.warning("No ids or filter provided for delete operation") + log.warning('No ids or filter provided for delete operation') except Exception as e: log.error(f"Error deleting from collection '{collection_name}': {e}") @@ -565,9 +497,9 @@ class PineconeClient(VectorDBBase): """Reset the database by deleting all collections.""" try: self.index.delete(delete_all=True) - log.info("All vectors successfully deleted from the index.") + log.info('All vectors successfully deleted from the index.') except Exception as e: - log.error(f"Failed to reset Pinecone index: {e}") + log.error(f'Failed to reset Pinecone index: {e}') raise def close(self): @@ -576,7 +508,7 @@ class PineconeClient(VectorDBBase): # The new Pinecone client doesn't need explicit closing pass except Exception as e: - log.warning(f"Failed to clean up Pinecone resources: {e}") + log.warning(f'Failed to clean up Pinecone resources: {e}') self._executor.shutdown(wait=True) def __enter__(self): diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index e774bb875f..f050bebeb5 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -76,19 +76,19 @@ class QdrantClient(VectorDBBase): for point in points: payload = point.payload ids.append(point.id) - documents.append(payload["text"]) - metadatas.append(payload["metadata"]) + documents.append(payload['text']) + metadatas.append(payload['metadata']) return GetResult( **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], + 'ids': [ids], + 'documents': [documents], + 'metadatas': [metadatas], } ) def _create_collection(self, collection_name: str, dimension: int): - collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" + collection_name_with_prefix = f'{self.collection_prefix}_{collection_name}' self.client.create_collection( collection_name=collection_name_with_prefix, vectors_config=models.VectorParams( @@ -104,7 +104,7 @@ class QdrantClient(VectorDBBase): # Create payload indexes for efficient filtering self.client.create_payload_index( collection_name=collection_name_with_prefix, - field_name="metadata.hash", + field_name='metadata.hash', field_schema=models.KeywordIndexParams( type=models.KeywordIndexType.KEYWORD, is_tenant=False, @@ -113,40 +113,34 @@ class QdrantClient(VectorDBBase): ) self.client.create_payload_index( collection_name=collection_name_with_prefix, - field_name="metadata.file_id", + field_name='metadata.file_id', field_schema=models.KeywordIndexParams( type=models.KeywordIndexType.KEYWORD, is_tenant=False, on_disk=self.QDRANT_ON_DISK, ), ) - log.info(f"collection {collection_name_with_prefix} successfully created!") + log.info(f'collection {collection_name_with_prefix} successfully created!') def _create_collection_if_not_exists(self, collection_name, dimension): if not self.has_collection(collection_name=collection_name): - self._create_collection( - collection_name=collection_name, dimension=dimension - ) + self._create_collection(collection_name=collection_name, dimension=dimension) def _create_points(self, items: list[VectorItem]): return [ PointStruct( - id=item["id"], - vector=item["vector"], - payload={"text": item["text"], "metadata": item["metadata"]}, + id=item['id'], + vector=item['vector'], + payload={'text': item['text'], 'metadata': item['metadata']}, ) for item in items ] def has_collection(self, collection_name: str) -> bool: - return self.client.collection_exists( - f"{self.collection_prefix}_{collection_name}" - ) + return self.client.collection_exists(f'{self.collection_prefix}_{collection_name}') def delete_collection(self, collection_name: str): - return self.client.delete_collection( - collection_name=f"{self.collection_prefix}_{collection_name}" - ) + return self.client.delete_collection(collection_name=f'{self.collection_prefix}_{collection_name}') def search( self, @@ -160,7 +154,7 @@ class QdrantClient(VectorDBBase): limit = NO_LIMIT # otherwise qdrant would set limit to 10! query_response = self.client.query_points( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', query=vectors[0], limit=limit, ) @@ -184,13 +178,11 @@ class QdrantClient(VectorDBBase): field_conditions = [] for key, value in filter.items(): field_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", match=models.MatchValue(value=value) - ) + models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value)) ) points = self.client.scroll( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', scroll_filter=models.Filter(should=field_conditions), limit=limit, ) @@ -202,22 +194,22 @@ class QdrantClient(VectorDBBase): def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. points = self.client.scroll( - collection_name=f"{self.collection_prefix}_{collection_name}", + collection_name=f'{self.collection_prefix}_{collection_name}', limit=NO_LIMIT, # otherwise qdrant would set limit to 10! ) return self._result_to_get_result(points[0]) def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. - self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + self._create_collection_if_not_exists(collection_name, len(items[0]['vector'])) points = self._create_points(items) - self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points) + self.client.upload_points(f'{self.collection_prefix}_{collection_name}', points) def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. - self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + self._create_collection_if_not_exists(collection_name, len(items[0]['vector'])) points = self._create_points(items) - return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) + return self.client.upsert(f'{self.collection_prefix}_{collection_name}', points) def delete( self, @@ -230,26 +222,28 @@ class QdrantClient(VectorDBBase): if ids: for id_value in ids: - field_conditions.append( - models.FieldCondition( - key="metadata.id", - match=models.MatchValue(value=id_value), + ( + field_conditions.append( + models.FieldCondition( + key='metadata.id', + match=models.MatchValue(value=id_value), + ), ), - ), + ) elif filter: for key, value in filter.items(): - field_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", - match=models.MatchValue(value=value), + ( + field_conditions.append( + models.FieldCondition( + key=f'metadata.{key}', + match=models.MatchValue(value=value), + ), ), - ), + ) return self.client.delete( - collection_name=f"{self.collection_prefix}_{collection_name}", - points_selector=models.FilterSelector( - filter=models.Filter(must=field_conditions) - ), + collection_name=f'{self.collection_prefix}_{collection_name}', + points_selector=models.FilterSelector(filter=models.Filter(must=field_conditions)), ) def reset(self): diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index 5ad2ac6929..c3c2ba41d0 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -29,22 +29,18 @@ from qdrant_client.http.models import PointStruct from qdrant_client.models import models NO_LIMIT = 999999999 -TENANT_ID_FIELD = "tenant_id" +TENANT_ID_FIELD = 'tenant_id' DEFAULT_DIMENSION = 384 log = logging.getLogger(__name__) def _tenant_filter(tenant_id: str) -> models.FieldCondition: - return models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) + return models.FieldCondition(key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)) def _metadata_filter(key: str, value: Any) -> models.FieldCondition: - return models.FieldCondition( - key=f"metadata.{key}", match=models.MatchValue(value=value) - ) + return models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value)) class QdrantClient(VectorDBBase): @@ -59,9 +55,7 @@ class QdrantClient(VectorDBBase): self.QDRANT_HNSW_M = QDRANT_HNSW_M if not self.QDRANT_URI: - raise ValueError( - "QDRANT_URI is not set. Please configure it in the environment variables." - ) + raise ValueError('QDRANT_URI is not set. Please configure it in the environment variables.') # Unified handling for either scheme parsed = urlparse(self.QDRANT_URI) @@ -86,19 +80,19 @@ class QdrantClient(VectorDBBase): ) # Main collection types for multi-tenancy - self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" - self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" - self.FILE_COLLECTION = f"{self.collection_prefix}_files" - self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search" - self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" + self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories' + self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge' + self.FILE_COLLECTION = f'{self.collection_prefix}_files' + self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web-search' + self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash-based' def _result_to_get_result(self, points) -> GetResult: ids, documents, metadatas = [], [], [] for point in points: payload = point.payload ids.append(point.id) - documents.append(payload["text"]) - metadatas.append(payload["metadata"]) + documents.append(payload['text']) + metadatas.append(payload['metadata']) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: @@ -118,29 +112,25 @@ class QdrantClient(VectorDBBase): # Check for user memory collections tenant_id = collection_name - if collection_name.startswith("user-memory-"): + if collection_name.startswith('user-memory-'): return self.MEMORY_COLLECTION, tenant_id # Check for file collections - elif collection_name.startswith("file-"): + elif collection_name.startswith('file-'): return self.FILE_COLLECTION, tenant_id # Check for web search collections - elif collection_name.startswith("web-search-"): + elif collection_name.startswith('web-search-'): return self.WEB_SEARCH_COLLECTION, tenant_id # Handle hash-based collections (YouTube and web URLs) - elif len(collection_name) == 63 and all( - c in "0123456789abcdef" for c in collection_name - ): + elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name): return self.HASH_BASED_COLLECTION, tenant_id else: return self.KNOWLEDGE_COLLECTION, tenant_id - def _create_multi_tenant_collection( - self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION - ): + def _create_multi_tenant_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION): """ Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields. """ @@ -158,9 +148,7 @@ class QdrantClient(VectorDBBase): m=0, ), ) - log.info( - f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" - ) + log.info(f'Multi-tenant collection {mt_collection_name} created with dimension {dimension}!') self.client.create_payload_index( collection_name=mt_collection_name, @@ -172,7 +160,7 @@ class QdrantClient(VectorDBBase): ), ) - for field in ("metadata.hash", "metadata.file_id"): + for field in ('metadata.hash', 'metadata.file_id'): self.client.create_payload_index( collection_name=mt_collection_name, field_name=field, @@ -182,28 +170,24 @@ class QdrantClient(VectorDBBase): ), ) - def _create_points( - self, items: List[VectorItem], tenant_id: str - ) -> List[PointStruct]: + def _create_points(self, items: List[VectorItem], tenant_id: str) -> List[PointStruct]: """ Create point structs from vector items with tenant ID. """ return [ PointStruct( - id=item["id"], - vector=item["vector"], + id=item['id'], + vector=item['vector'], payload={ - "text": item["text"], - "metadata": item["metadata"], + 'text': item['text'], + 'metadata': item['metadata'], TENANT_ID_FIELD: tenant_id, }, ) for item in items ] - def _ensure_collection( - self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION - ): + def _ensure_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION): """ Ensure the collection exists and payload indexes are created for tenant_id and metadata fields. """ @@ -246,15 +230,13 @@ class QdrantClient(VectorDBBase): must_conditions = [_tenant_filter(tenant_id)] should_conditions = [] if ids: - should_conditions = [_metadata_filter("id", id_value) for id_value in ids] + should_conditions = [_metadata_filter('id', id_value) for id_value in ids] elif filter: must_conditions += [_metadata_filter(k, v) for k, v in filter.items()] return self.client.delete( collection_name=mt_collection, - points_selector=models.FilterSelector( - filter=models.Filter(must=must_conditions, should=should_conditions) - ), + points_selector=models.FilterSelector(filter=models.Filter(must=must_conditions, should=should_conditions)), ) def search( @@ -289,9 +271,7 @@ class QdrantClient(VectorDBBase): distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]], ) - def query( - self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None - ): + def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None): """ Query points with filters and tenant isolation. """ @@ -338,7 +318,7 @@ class QdrantClient(VectorDBBase): if not self.client or not items: return None mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - dimension = len(items[0]["vector"]) + dimension = len(items[0]['vector']) self._ensure_collection(mt_collection, dimension) points = self._create_points(items, tenant_id) self.client.upload_points(mt_collection, points) @@ -372,7 +352,5 @@ class QdrantClient(VectorDBBase): return None self.client.delete( collection_name=mt_collection, - points_selector=models.FilterSelector( - filter=models.Filter(must=[_tenant_filter(tenant_id)]) - ), + points_selector=models.FilterSelector(filter=models.Filter(must=[_tenant_filter(tenant_id)])), ) diff --git a/backend/open_webui/retrieval/vector/dbs/s3vector.py b/backend/open_webui/retrieval/vector/dbs/s3vector.py index 1a30e04e55..8877d206e6 100644 --- a/backend/open_webui/retrieval/vector/dbs/s3vector.py +++ b/backend/open_webui/retrieval/vector/dbs/s3vector.py @@ -28,18 +28,16 @@ class S3VectorClient(VectorDBBase): # Simple validation - log warnings instead of raising exceptions if not self.bucket_name: - log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work") + log.warning('S3_VECTOR_BUCKET_NAME not set - S3Vector will not work') if not self.region: - log.warning("S3_VECTOR_REGION not set - S3Vector will not work") + log.warning('S3_VECTOR_REGION not set - S3Vector will not work') if self.bucket_name and self.region: try: - self.client = boto3.client("s3vectors", region_name=self.region) - log.info( - f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'" - ) + self.client = boto3.client('s3vectors', region_name=self.region) + log.info(f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'") except Exception as e: - log.error(f"Failed to initialize S3Vector client: {e}") + log.error(f'Failed to initialize S3Vector client: {e}') self.client = None else: self.client = None @@ -48,8 +46,8 @@ class S3VectorClient(VectorDBBase): self, index_name: str, dimension: int, - data_type: str = "float32", - distance_metric: str = "cosine", + data_type: str = 'float32', + distance_metric: str = 'cosine', ) -> None: """ Create a new index in the S3 vector bucket for the given collection if it does not exist. @@ -66,21 +64,17 @@ class S3VectorClient(VectorDBBase): dimension=dimension, distanceMetric=distance_metric, metadataConfiguration={ - "nonFilterableMetadataKeys": [ - "text", + 'nonFilterableMetadataKeys': [ + 'text', ] }, ) - log.info( - f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})" - ) + log.info(f'Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})') except Exception as e: log.error(f"Error creating S3 index '{index_name}': {e}") raise - def _filter_metadata( - self, metadata: Dict[str, Any], item_id: str - ) -> Dict[str, Any]: + def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]: """ Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum. """ @@ -89,16 +83,16 @@ class S3VectorClient(VectorDBBase): # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata important_keys = [ - "text", # The actual document content - "file_id", # File ID - "source", # Document source file - "title", # Document title - "page", # Page number - "total_pages", # Total pages in document - "embedding_config", # Embedding configuration - "created_by", # User who created it - "name", # Document name - "hash", # Content hash + 'text', # The actual document content + 'file_id', # File ID + 'source', # Document source file + 'title', # Document title + 'page', # Page number + 'total_pages', # Total pages in document + 'embedding_config', # Embedding configuration + 'created_by', # User who created it + 'name', # Document name + 'hash', # Content hash ] filtered_metadata = {} @@ -117,9 +111,7 @@ class S3VectorClient(VectorDBBase): if len(filtered_metadata) >= 10: break - log.warning( - f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys" - ) + log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys") return filtered_metadata def has_collection(self, collection_name: str) -> bool: @@ -128,9 +120,7 @@ class S3VectorClient(VectorDBBase): This avoids pagination issues with list_indexes() and is significantly faster. """ try: - self.client.get_index( - vectorBucketName=self.bucket_name, indexName=collection_name - ) + self.client.get_index(vectorBucketName=self.bucket_name, indexName=collection_name) return True except Exception as e: log.error(f"Error checking if index '{collection_name}' exists: {e}") @@ -142,16 +132,12 @@ class S3VectorClient(VectorDBBase): """ if not self.has_collection(collection_name): - log.warning( - f"Collection '{collection_name}' does not exist, nothing to delete" - ) + log.warning(f"Collection '{collection_name}' does not exist, nothing to delete") return try: log.info(f"Deleting collection '{collection_name}'") - self.client.delete_index( - vectorBucketName=self.bucket_name, indexName=collection_name - ) + self.client.delete_index(vectorBucketName=self.bucket_name, indexName=collection_name) log.info(f"Successfully deleted collection '{collection_name}'") except Exception as e: log.error(f"Error deleting collection '{collection_name}': {e}") @@ -162,10 +148,10 @@ class S3VectorClient(VectorDBBase): Insert vector items into the S3 Vector index. Create index if it does not exist. """ if not items: - log.warning("No items to insert") + log.warning('No items to insert') return - dimension = len(items[0]["vector"]) + dimension = len(items[0]['vector']) try: if not self.has_collection(collection_name): @@ -173,36 +159,36 @@ class S3VectorClient(VectorDBBase): self._create_index( index_name=collection_name, dimension=dimension, - data_type="float32", - distance_metric="cosine", + data_type='float32', + distance_metric='cosine', ) # Prepare vectors for insertion vectors = [] for item in items: # Ensure vector data is in the correct format for S3 Vector API - vector_data = item["vector"] + vector_data = item['vector'] if isinstance(vector_data, list): # Convert list to float32 values as required by S3 Vector API vector_data = [float(x) for x in vector_data] # Prepare metadata, ensuring the text field is preserved - metadata = item.get("metadata", {}).copy() + metadata = item.get('metadata', {}).copy() # Add the text field to metadata so it's available for retrieval - metadata["text"] = item["text"] + metadata['text'] = item['text'] # Convert metadata to string format for consistency metadata = process_metadata(metadata) # Filter metadata to comply with S3 Vector API limit of 10 keys - metadata = self._filter_metadata(metadata, item["id"]) + metadata = self._filter_metadata(metadata, item['id']) vectors.append( { - "key": item["id"], - "data": {"float32": vector_data}, - "metadata": metadata, + 'key': item['id'], + 'data': {'float32': vector_data}, + 'metadata': metadata, } ) @@ -215,15 +201,11 @@ class S3VectorClient(VectorDBBase): indexName=collection_name, vectors=batch, ) - log.info( - f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'." - ) + log.info(f"Inserted batch {i // batch_size + 1}: {len(batch)} vectors into index '{collection_name}'.") - log.info( - f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'." - ) + log.info(f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'.") except Exception as e: - log.error(f"Error inserting vectors: {e}") + log.error(f'Error inserting vectors: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -231,49 +213,47 @@ class S3VectorClient(VectorDBBase): Insert or update vector items in the S3 Vector index. Create index if it does not exist. """ if not items: - log.warning("No items to upsert") + log.warning('No items to upsert') return - dimension = len(items[0]["vector"]) - log.info(f"Upsert dimension: {dimension}") + dimension = len(items[0]['vector']) + log.info(f'Upsert dimension: {dimension}') try: if not self.has_collection(collection_name): - log.info( - f"Index '{collection_name}' does not exist. Creating index for upsert." - ) + log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.") self._create_index( index_name=collection_name, dimension=dimension, - data_type="float32", - distance_metric="cosine", + data_type='float32', + distance_metric='cosine', ) # Prepare vectors for upsert vectors = [] for item in items: # Ensure vector data is in the correct format for S3 Vector API - vector_data = item["vector"] + vector_data = item['vector'] if isinstance(vector_data, list): # Convert list to float32 values as required by S3 Vector API vector_data = [float(x) for x in vector_data] # Prepare metadata, ensuring the text field is preserved - metadata = item.get("metadata", {}).copy() + metadata = item.get('metadata', {}).copy() # Add the text field to metadata so it's available for retrieval - metadata["text"] = item["text"] + metadata['text'] = item['text'] # Convert metadata to string format for consistency metadata = process_metadata(metadata) # Filter metadata to comply with S3 Vector API limit of 10 keys - metadata = self._filter_metadata(metadata, item["id"]) + metadata = self._filter_metadata(metadata, item['id']) vectors.append( { - "key": item["id"], - "data": {"float32": vector_data}, - "metadata": metadata, + 'key': item['id'], + 'data': {'float32': vector_data}, + 'metadata': metadata, } ) @@ -283,12 +263,10 @@ class S3VectorClient(VectorDBBase): batch = vectors[i : i + batch_size] if i == 0: # Log sample info for first batch only log.info( - f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}" + f'Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]["key"]}, data_type={type(batch[0]["data"]["float32"])}, data_len={len(batch[0]["data"]["float32"])}' ) else: - log.info( - f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors." - ) + log.info(f'Upserting batch {i // batch_size + 1}: {len(batch)} vectors.') self.client.put_vectors( vectorBucketName=self.bucket_name, @@ -296,11 +274,9 @@ class S3VectorClient(VectorDBBase): vectors=batch, ) - log.info( - f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'." - ) + log.info(f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'.") except Exception as e: - log.error(f"Error upserting vectors: {e}") + log.error(f'Error upserting vectors: {e}') raise def search( @@ -319,13 +295,11 @@ class S3VectorClient(VectorDBBase): return None if not vectors: - log.warning("No query vectors provided") + log.warning('No query vectors provided') return None try: - log.info( - f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}" - ) + log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}") # Initialize result lists all_ids = [] @@ -335,10 +309,10 @@ class S3VectorClient(VectorDBBase): # Process each query vector for i, query_vector in enumerate(vectors): - log.debug(f"Processing query vector {i+1}/{len(vectors)}") + log.debug(f'Processing query vector {i + 1}/{len(vectors)}') # Prepare the query vector in S3 Vector format - query_vector_dict = {"float32": [float(x) for x in query_vector]} + query_vector_dict = {'float32': [float(x) for x in query_vector]} # Call S3 Vector query API response = self.client.query_vectors( @@ -356,24 +330,22 @@ class S3VectorClient(VectorDBBase): query_metadatas = [] query_distances = [] - result_vectors = response.get("vectors", []) + result_vectors = response.get('vectors', []) for vector in result_vectors: - vector_id = vector.get("key") - vector_metadata = vector.get("metadata", {}) - vector_distance = vector.get("distance", 0.0) + vector_id = vector.get('key') + vector_metadata = vector.get('metadata', {}) + vector_distance = vector.get('distance', 0.0) # Extract document text from metadata - document_text = "" + document_text = '' if isinstance(vector_metadata, dict): # Get the text field first (highest priority) - document_text = vector_metadata.get("text") + document_text = vector_metadata.get('text') if not document_text: # Fallback to other possible text fields document_text = ( - vector_metadata.get("content") - or vector_metadata.get("document") - or vector_id + vector_metadata.get('content') or vector_metadata.get('document') or vector_id ) else: document_text = vector_id @@ -389,7 +361,7 @@ class S3VectorClient(VectorDBBase): all_metadatas.append(query_metadatas) all_distances.append(query_distances) - log.info(f"Search completed. Found results for {len(all_ids)} queries") + log.info(f'Search completed. Found results for {len(all_ids)} queries') # Return SearchResult format return SearchResult( @@ -402,24 +374,20 @@ class S3VectorClient(VectorDBBase): except Exception as e: log.error(f"Error searching collection '{collection_name}': {str(e)}") # Handle specific AWS exceptions - if hasattr(e, "response") and "Error" in e.response: - error_code = e.response["Error"]["Code"] - if error_code == "NotFoundException": + if hasattr(e, 'response') and 'Error' in e.response: + error_code = e.response['Error']['Code'] + if error_code == 'NotFoundException': log.warning(f"Collection '{collection_name}' not found") return None - elif error_code == "ValidationException": - log.error(f"Invalid query vector dimensions or parameters") + elif error_code == 'ValidationException': + log.error(f'Invalid query vector dimensions or parameters') return None - elif error_code == "AccessDeniedException": - log.error( - f"Access denied for collection '{collection_name}'. Check permissions." - ) + elif error_code == 'AccessDeniedException': + log.error(f"Access denied for collection '{collection_name}'. Check permissions.") return None raise - def query( - self, collection_name: str, filter: Dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]: """ Query vectors from a collection using metadata filter. """ @@ -429,7 +397,7 @@ class S3VectorClient(VectorDBBase): return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) if not filter: - log.warning("No filter provided, returning all vectors") + log.warning('No filter provided, returning all vectors') return self.get(collection_name) try: @@ -443,17 +411,13 @@ class S3VectorClient(VectorDBBase): all_vectors_result = self.get(collection_name) if not all_vectors_result or not all_vectors_result.ids: - log.warning("No vectors found in collection") + log.warning('No vectors found in collection') return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) # Extract the lists from the result all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else [] - all_documents = ( - all_vectors_result.documents[0] if all_vectors_result.documents else [] - ) - all_metadatas = ( - all_vectors_result.metadatas[0] if all_vectors_result.metadatas else [] - ) + all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else [] + all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else [] # Apply client-side filtering filtered_ids = [] @@ -472,9 +436,7 @@ class S3VectorClient(VectorDBBase): if limit and len(filtered_ids) >= limit: break - log.info( - f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total" - ) + log.info(f'Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total') # Return GetResult format if filtered_ids: @@ -489,15 +451,13 @@ class S3VectorClient(VectorDBBase): except Exception as e: log.error(f"Error querying collection '{collection_name}': {str(e)}") # Handle specific AWS exceptions - if hasattr(e, "response") and "Error" in e.response: - error_code = e.response["Error"]["Code"] - if error_code == "NotFoundException": + if hasattr(e, 'response') and 'Error' in e.response: + error_code = e.response['Error']['Code'] + if error_code == 'NotFoundException': log.warning(f"Collection '{collection_name}' not found") return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) - elif error_code == "AccessDeniedException": - log.error( - f"Access denied for collection '{collection_name}'. Check permissions." - ) + elif error_code == 'AccessDeniedException': + log.error(f"Access denied for collection '{collection_name}'. Check permissions.") return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) raise @@ -524,47 +484,43 @@ class S3VectorClient(VectorDBBase): while True: # Prepare request parameters request_params = { - "vectorBucketName": self.bucket_name, - "indexName": collection_name, - "returnData": False, # Don't include vector data (not needed for get) - "returnMetadata": True, # Include metadata - "maxResults": 500, # Use reasonable page size + 'vectorBucketName': self.bucket_name, + 'indexName': collection_name, + 'returnData': False, # Don't include vector data (not needed for get) + 'returnMetadata': True, # Include metadata + 'maxResults': 500, # Use reasonable page size } if next_token: - request_params["nextToken"] = next_token + request_params['nextToken'] = next_token # Call S3 Vector API response = self.client.list_vectors(**request_params) # Process vectors in this page - vectors = response.get("vectors", []) + vectors = response.get('vectors', []) for vector in vectors: - vector_id = vector.get("key") - vector_data = vector.get("data", {}) - vector_metadata = vector.get("metadata", {}) + vector_id = vector.get('key') + vector_data = vector.get('data', {}) + vector_metadata = vector.get('metadata', {}) # Extract the actual vector array - vector_array = vector_data.get("float32", []) + vector_array = vector_data.get('float32', []) # For documents, we try to extract text from metadata or use the vector ID - document_text = "" + document_text = '' if isinstance(vector_metadata, dict): # Get the text field first (highest priority) - document_text = vector_metadata.get("text") + document_text = vector_metadata.get('text') if not document_text: # Fallback to other possible text fields document_text = ( - vector_metadata.get("content") - or vector_metadata.get("document") - or vector_id + vector_metadata.get('content') or vector_metadata.get('document') or vector_id ) # Log the actual content for debugging - log.debug( - f"Document text preview (first 200 chars): {str(document_text)[:200]}" - ) + log.debug(f'Document text preview (first 200 chars): {str(document_text)[:200]}') else: document_text = vector_id @@ -573,37 +529,29 @@ class S3VectorClient(VectorDBBase): all_metadatas.append(vector_metadata) # Check if there are more pages - next_token = response.get("nextToken") + next_token = response.get('nextToken') if not next_token: break - log.info( - f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'" - ) + log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'") # Return in GetResult format # The Open WebUI GetResult expects lists of lists, so we wrap each list if all_ids: - return GetResult( - ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas] - ) + return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]) else: return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) except Exception as e: - log.error( - f"Error retrieving vectors from collection '{collection_name}': {str(e)}" - ) + log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}") # Handle specific AWS exceptions - if hasattr(e, "response") and "Error" in e.response: - error_code = e.response["Error"]["Code"] - if error_code == "NotFoundException": + if hasattr(e, 'response') and 'Error' in e.response: + error_code = e.response['Error']['Code'] + if error_code == 'NotFoundException': log.warning(f"Collection '{collection_name}' not found") return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) - elif error_code == "AccessDeniedException": - log.error( - f"Access denied for collection '{collection_name}'. Check permissions." - ) + elif error_code == 'AccessDeniedException': + log.error(f"Access denied for collection '{collection_name}'. Check permissions.") return GetResult(ids=[[]], documents=[[]], metadatas=[[]]) raise @@ -618,20 +566,16 @@ class S3VectorClient(VectorDBBase): """ if not self.has_collection(collection_name): - log.warning( - f"Collection '{collection_name}' does not exist, nothing to delete" - ) + log.warning(f"Collection '{collection_name}' does not exist, nothing to delete") return # Check if this is a knowledge collection (not file-specific) - is_knowledge_collection = not collection_name.startswith("file-") + is_knowledge_collection = not collection_name.startswith('file-') try: if ids: # Delete by specific vector IDs/keys - log.info( - f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'" - ) + log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'") self.client.delete_vectors( vectorBucketName=self.bucket_name, indexName=collection_name, @@ -641,15 +585,13 @@ class S3VectorClient(VectorDBBase): elif filter: # Handle filter-based deletion - log.info( - f"Deleting vectors by filter from collection '{collection_name}': {filter}" - ) + log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}") # If this is a knowledge collection and we have a file_id filter, # also clean up the corresponding file-specific collection - if is_knowledge_collection and "file_id" in filter: - file_id = filter["file_id"] - file_collection_name = f"file-{file_id}" + if is_knowledge_collection and 'file_id' in filter: + file_id = filter['file_id'] + file_collection_name = f'file-{file_id}' if self.has_collection(file_collection_name): log.info( f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates" @@ -661,9 +603,7 @@ class S3VectorClient(VectorDBBase): query_result = self.query(collection_name, filter) if query_result and query_result.ids and query_result.ids[0]: matching_ids = query_result.ids[0] - log.info( - f"Found {len(matching_ids)} vectors matching filter, deleting them" - ) + log.info(f'Found {len(matching_ids)} vectors matching filter, deleting them') # Delete the matching vectors by ID self.client.delete_vectors( @@ -671,17 +611,13 @@ class S3VectorClient(VectorDBBase): indexName=collection_name, keys=matching_ids, ) - log.info( - f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter" - ) + log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter") else: - log.warning("No vectors found matching the filter criteria") + log.warning('No vectors found matching the filter criteria') else: - log.warning("No IDs or filter provided for deletion") + log.warning('No IDs or filter provided for deletion') except Exception as e: - log.error( - f"Error deleting vectors from collection '{collection_name}': {e}" - ) + log.error(f"Error deleting vectors from collection '{collection_name}': {e}") raise def reset(self) -> None: @@ -690,36 +626,32 @@ class S3VectorClient(VectorDBBase): """ try: - log.warning( - "Reset called - this will delete all vector indexes in the S3 bucket" - ) + log.warning('Reset called - this will delete all vector indexes in the S3 bucket') # List all indexes response = self.client.list_indexes(vectorBucketName=self.bucket_name) - indexes = response.get("indexes", []) + indexes = response.get('indexes', []) if not indexes: - log.warning("No indexes found to delete") + log.warning('No indexes found to delete') return # Delete all indexes deleted_count = 0 for index in indexes: - index_name = index.get("indexName") + index_name = index.get('indexName') if index_name: try: - self.client.delete_index( - vectorBucketName=self.bucket_name, indexName=index_name - ) + self.client.delete_index(vectorBucketName=self.bucket_name, indexName=index_name) deleted_count += 1 - log.info(f"Deleted index: {index_name}") + log.info(f'Deleted index: {index_name}') except Exception as e: log.error(f"Error deleting index '{index_name}': {e}") - log.info(f"Reset completed: deleted {deleted_count} indexes") + log.info(f'Reset completed: deleted {deleted_count} indexes') except Exception as e: - log.error(f"Error during reset: {e}") + log.error(f'Error during reset: {e}') raise def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool: @@ -732,15 +664,15 @@ class S3VectorClient(VectorDBBase): # Check each filter condition for key, expected_value in filter.items(): # Handle special operators - if key.startswith("$"): - if key == "$and": + if key.startswith('$'): + if key == '$and': # All conditions must match if not isinstance(expected_value, list): continue for condition in expected_value: if not self._matches_filter(metadata, condition): return False - elif key == "$or": + elif key == '$or': # At least one condition must match if not isinstance(expected_value, list): continue @@ -760,22 +692,19 @@ class S3VectorClient(VectorDBBase): if isinstance(expected_value, dict): # Handle comparison operators for op, op_value in expected_value.items(): - if op == "$eq": + if op == '$eq': if actual_value != op_value: return False - elif op == "$ne": + elif op == '$ne': if actual_value == op_value: return False - elif op == "$in": - if ( - not isinstance(op_value, list) - or actual_value not in op_value - ): + elif op == '$in': + if not isinstance(op_value, list) or actual_value not in op_value: return False - elif op == "$nin": + elif op == '$nin': if isinstance(op_value, list) and actual_value in op_value: return False - elif op == "$exists": + elif op == '$exists': if bool(op_value) != (key in metadata): return False # Add more operators as needed diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py index c9b09ad638..2cf4c135c5 100644 --- a/backend/open_webui/retrieval/vector/dbs/weaviate.py +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -60,47 +60,43 @@ class WeaviateClient(VectorDBBase): try: # Build connection parameters connection_params = { - "http_host": WEAVIATE_HTTP_HOST, - "http_port": WEAVIATE_HTTP_PORT, - "http_secure": WEAVIATE_HTTP_SECURE, - "grpc_host": WEAVIATE_GRPC_HOST, - "grpc_port": WEAVIATE_GRPC_PORT, - "grpc_secure": WEAVIATE_GRPC_SECURE, - "skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS, + 'http_host': WEAVIATE_HTTP_HOST, + 'http_port': WEAVIATE_HTTP_PORT, + 'http_secure': WEAVIATE_HTTP_SECURE, + 'grpc_host': WEAVIATE_GRPC_HOST, + 'grpc_port': WEAVIATE_GRPC_PORT, + 'grpc_secure': WEAVIATE_GRPC_SECURE, + 'skip_init_checks': WEAVIATE_SKIP_INIT_CHECKS, } # Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty if WEAVIATE_API_KEY: - connection_params["auth_credentials"] = ( - weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY) - ) + connection_params['auth_credentials'] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY) self.client = weaviate.connect_to_custom(**connection_params) self.client.connect() except Exception as e: - raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e + raise ConnectionError(f'Failed to connect to Weaviate: {e}') from e def _sanitize_collection_name(self, collection_name: str) -> str: """Sanitize collection name to be a valid Weaviate class name.""" if not isinstance(collection_name, str) or not collection_name.strip(): - raise ValueError("Collection name must be a non-empty string") + raise ValueError('Collection name must be a non-empty string') # Requirements for a valid Weaviate class name: # The collection name must begin with a capital letter. # The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed. # Replace hyphens with underscores and keep only alphanumeric characters - name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_")) - name = name.strip("_") + name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace('-', '_')) + name = name.strip('_') if not name: - raise ValueError( - "Could not sanitize collection name to be a valid Weaviate class name" - ) + raise ValueError('Could not sanitize collection name to be a valid Weaviate class name') # Ensure it starts with a letter and is capitalized if not name[0].isalpha(): - name = "C" + name + name = 'C' + name return name[0].upper() + name[1:] @@ -118,9 +114,7 @@ class WeaviateClient(VectorDBBase): name=collection_name, vector_config=weaviate.classes.config.Configure.Vectors.self_provided(), properties=[ - weaviate.classes.config.Property( - name="text", data_type=weaviate.classes.config.DataType.TEXT - ), + weaviate.classes.config.Property(name='text', data_type=weaviate.classes.config.DataType.TEXT), ], ) @@ -133,19 +127,15 @@ class WeaviateClient(VectorDBBase): with collection.batch.fixed_size(batch_size=100) as batch: for item in items: - item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"]) + item_uuid = str(uuid.uuid4()) if not item['id'] else str(item['id']) - properties = {"text": item["text"]} - if item["metadata"]: - clean_metadata = _convert_uuids_to_strings( - process_metadata(item["metadata"]) - ) - clean_metadata.pop("text", None) + properties = {'text': item['text']} + if item['metadata']: + clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata'])) + clean_metadata.pop('text', None) properties.update(clean_metadata) - batch.add_object( - properties=properties, uuid=item_uuid, vector=item["vector"] - ) + batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector']) def upsert(self, collection_name: str, items: List[VectorItem]) -> None: sane_collection_name = self._sanitize_collection_name(collection_name) @@ -156,19 +146,15 @@ class WeaviateClient(VectorDBBase): with collection.batch.fixed_size(batch_size=100) as batch: for item in items: - item_uuid = str(item["id"]) if item["id"] else None + item_uuid = str(item['id']) if item['id'] else None - properties = {"text": item["text"]} - if item["metadata"]: - clean_metadata = _convert_uuids_to_strings( - process_metadata(item["metadata"]) - ) - clean_metadata.pop("text", None) + properties = {'text': item['text']} + if item['metadata']: + clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata'])) + clean_metadata.pop('text', None) properties.update(clean_metadata) - batch.add_object( - properties=properties, uuid=item_uuid, vector=item["vector"] - ) + batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector']) def search( self, @@ -205,16 +191,12 @@ class WeaviateClient(VectorDBBase): for obj in response.objects: properties = dict(obj.properties) if obj.properties else {} - documents.append(properties.pop("text", "")) + documents.append(properties.pop('text', '')) metadatas.append(_convert_uuids_to_strings(properties)) # Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1 raw_distances = [ - ( - obj.metadata.distance - if obj.metadata and obj.metadata.distance - else 2.0 - ) + (obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0) for obj in response.objects ] distances = [(2 - dist) / 2 for dist in raw_distances] @@ -231,16 +213,14 @@ class WeaviateClient(VectorDBBase): return SearchResult( **{ - "ids": result_ids, - "documents": result_documents, - "metadatas": result_metadatas, - "distances": result_distances, + 'ids': result_ids, + 'documents': result_documents, + 'metadatas': result_metadatas, + 'distances': result_distances, } ) - def query( - self, collection_name: str, filter: Dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]: sane_collection_name = self._sanitize_collection_name(collection_name) if not self.client.collections.exists(sane_collection_name): return None @@ -250,21 +230,15 @@ class WeaviateClient(VectorDBBase): weaviate_filter = None if filter: for key, value in filter.items(): - prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal( - value - ) + prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value) weaviate_filter = ( prop_filter if weaviate_filter is None - else weaviate.classes.query.Filter.all_of( - [weaviate_filter, prop_filter] - ) + else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter]) ) try: - response = collection.query.fetch_objects( - filters=weaviate_filter, limit=limit - ) + response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit) ids = [str(obj.uuid) for obj in response.objects] documents = [] @@ -272,14 +246,14 @@ class WeaviateClient(VectorDBBase): for obj in response.objects: properties = dict(obj.properties) if obj.properties else {} - documents.append(properties.pop("text", "")) + documents.append(properties.pop('text', '')) metadatas.append(_convert_uuids_to_strings(properties)) return GetResult( **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], + 'ids': [ids], + 'documents': [documents], + 'metadatas': [metadatas], } ) except Exception: @@ -297,7 +271,7 @@ class WeaviateClient(VectorDBBase): for item in collection.iterator(): ids.append(str(item.uuid)) properties = dict(item.properties) if item.properties else {} - documents.append(properties.pop("text", "")) + documents.append(properties.pop('text', '')) metadatas.append(_convert_uuids_to_strings(properties)) if not ids: @@ -305,9 +279,9 @@ class WeaviateClient(VectorDBBase): return GetResult( **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], + 'ids': [ids], + 'documents': [documents], + 'metadatas': [metadatas], } ) except Exception: @@ -332,15 +306,11 @@ class WeaviateClient(VectorDBBase): elif filter: weaviate_filter = None for key, value in filter.items(): - prop_filter = weaviate.classes.query.Filter.by_property( - name=key - ).equal(value) + prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value) weaviate_filter = ( prop_filter if weaviate_filter is None - else weaviate.classes.query.Filter.all_of( - [weaviate_filter, prop_filter] - ) + else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter]) ) if weaviate_filter: diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index d92b335864..8c0208fd4f 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -8,7 +8,6 @@ from open_webui.config import ( class Vector: - @staticmethod def get_vector(vector_type: str) -> VectorDBBase: """ @@ -82,7 +81,7 @@ class Vector: return WeaviateClient() case _: - raise ValueError(f"Unsupported vector type: {vector_type}") + raise ValueError(f'Unsupported vector type: {vector_type}') VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB) diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py index a76fec9956..f7904baa20 100644 --- a/backend/open_webui/retrieval/vector/main.py +++ b/backend/open_webui/retrieval/vector/main.py @@ -63,9 +63,7 @@ class VectorDBBase(ABC): pass @abstractmethod - def query( - self, collection_name: str, filter: Dict, limit: Optional[int] = None - ) -> Optional[GetResult]: + def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]: """Query vectors from a collection using metadata filter.""" pass diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py index df9453aa3e..999aee9c54 100644 --- a/backend/open_webui/retrieval/vector/type.py +++ b/backend/open_webui/retrieval/vector/type.py @@ -2,15 +2,15 @@ from enum import StrEnum class VectorType(StrEnum): - MILVUS = "milvus" - MARIADB_VECTOR = "mariadb-vector" - QDRANT = "qdrant" - CHROMA = "chroma" - PINECONE = "pinecone" - ELASTICSEARCH = "elasticsearch" - OPENSEARCH = "opensearch" - PGVECTOR = "pgvector" - ORACLE23AI = "oracle23ai" - S3VECTOR = "s3vector" - WEAVIATE = "weaviate" - OPENGAUSS = "opengauss" + MILVUS = 'milvus' + MARIADB_VECTOR = 'mariadb-vector' + QDRANT = 'qdrant' + CHROMA = 'chroma' + PINECONE = 'pinecone' + ELASTICSEARCH = 'elasticsearch' + OPENSEARCH = 'opensearch' + PGVECTOR = 'pgvector' + ORACLE23AI = 'oracle23ai' + S3VECTOR = 's3vector' + WEAVIATE = 'weaviate' + OPENGAUSS = 'opengauss' diff --git a/backend/open_webui/retrieval/vector/utils.py b/backend/open_webui/retrieval/vector/utils.py index a39d364419..b2e2fed762 100644 --- a/backend/open_webui/retrieval/vector/utils.py +++ b/backend/open_webui/retrieval/vector/utils.py @@ -1,13 +1,11 @@ from datetime import datetime -KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"] +KEYS_TO_EXCLUDE = ['content', 'pages', 'tables', 'paragraphs', 'sections', 'figures'] def filter_metadata(metadata: dict[str, any]) -> dict[str, any]: # Removes large/redundant fields from metadata dict. - metadata = { - key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE - } + metadata = {key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE} return metadata diff --git a/backend/open_webui/retrieval/web/azure.py b/backend/open_webui/retrieval/web/azure.py index 3859ccc9b7..4f74ecc982 100644 --- a/backend/open_webui/retrieval/web/azure.py +++ b/backend/open_webui/retrieval/web/azure.py @@ -40,20 +40,17 @@ def search_azure( from azure.search.documents import SearchClient except ImportError: log.error( - "azure-search-documents package is not installed. " - "Install it with: pip install azure-search-documents" + 'azure-search-documents package is not installed. Install it with: pip install azure-search-documents' ) raise ImportError( - "azure-search-documents is required for Azure AI Search. " - "Install it with: pip install azure-search-documents" + 'azure-search-documents is required for Azure AI Search. ' + 'Install it with: pip install azure-search-documents' ) try: # Create search client with API key authentication credential = AzureKeyCredential(api_key) - search_client = SearchClient( - endpoint=endpoint, index_name=index_name, credential=credential - ) + search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential) # Perform the search results = search_client.search(search_text=query, top=count) @@ -68,42 +65,42 @@ def search_azure( # Try to find URL field (common names) link = ( - result_dict.get("url") - or result_dict.get("link") - or result_dict.get("uri") - or result_dict.get("metadata_storage_path") - or "" + result_dict.get('url') + or result_dict.get('link') + or result_dict.get('uri') + or result_dict.get('metadata_storage_path') + or '' ) # Try to find title field (common names) title = ( - result_dict.get("title") - or result_dict.get("name") - or result_dict.get("metadata_title") - or result_dict.get("metadata_storage_name") + result_dict.get('title') + or result_dict.get('name') + or result_dict.get('metadata_title') + or result_dict.get('metadata_storage_name') or None ) # Try to find content/snippet field (common names) snippet = ( - result_dict.get("content") - or result_dict.get("snippet") - or result_dict.get("description") - or result_dict.get("summary") - or result_dict.get("text") + result_dict.get('content') + or result_dict.get('snippet') + or result_dict.get('description') + or result_dict.get('summary') + or result_dict.get('text') or None ) # Truncate snippet if too long if snippet and len(snippet) > 500: - snippet = snippet[:497] + "..." + snippet = snippet[:497] + '...' if link: # Only add if we found a valid link search_results.append( { - "link": link, - "title": title, - "snippet": snippet, + 'link': link, + 'title': title, + 'snippet': snippet, } ) @@ -114,13 +111,13 @@ def search_azure( # Convert to SearchResult objects return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("snippet"), + link=result['link'], + title=result.get('title'), + snippet=result.get('snippet'), ) for result in search_results ] except Exception as ex: - log.error(f"Azure AI Search error: {ex}") + log.error(f'Azure AI Search error: {ex}') raise ex diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py index 4c9822b900..b7cfea89de 100644 --- a/backend/open_webui/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -21,48 +21,44 @@ def search_bing( filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: mkt = locale - params = {"q": query, "mkt": mkt, "count": count} - headers = {"Ocp-Apim-Subscription-Key": subscription_key} + params = {'q': query, 'mkt': mkt, 'count': count} + headers = {'Ocp-Apim-Subscription-Key': subscription_key} try: response = requests.get(endpoint, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = json_response.get("webPages", {}).get("value", []) + results = json_response.get('webPages', {}).get('value', []) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["url"], - title=result.get("name"), - snippet=result.get("snippet"), + link=result['url'], + title=result.get('name'), + snippet=result.get('snippet'), ) for result in results ] except Exception as ex: - log.error(f"Error: {ex}") + log.error(f'Error: {ex}') raise ex def main(): - parser = argparse.ArgumentParser(description="Search Bing from the command line.") + parser = argparse.ArgumentParser(description='Search Bing from the command line.') parser.add_argument( - "query", + 'query', type=str, - default="Top 10 international news today", - help="The search query.", + default='Top 10 international news today', + help='The search query.', ) + parser.add_argument('--count', type=int, default=10, help='Number of search results to return.') + parser.add_argument('--filter', nargs='*', help='List of filters to apply to the search results.') parser.add_argument( - "--count", type=int, default=10, help="Number of search results to return." - ) - parser.add_argument( - "--filter", nargs="*", help="List of filters to apply to the search results." - ) - parser.add_argument( - "--locale", + '--locale', type=str, - default="en-US", - help="The locale to use for the search, maps to market in api", + default='en-US', + help='The locale to use for the search, maps to market in api', ) args = parser.parse_args() diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py index 7e3c9b0a40..3557dcffb9 100644 --- a/backend/open_webui/retrieval/web/bocha.py +++ b/backend/open_webui/retrieval/web/bocha.py @@ -10,43 +10,38 @@ log = logging.getLogger(__name__) def _parse_response(response): results = [] - if "data" in response: - data = response["data"] - if "webPages" in data: - webPages = data["webPages"] - if "value" in webPages: + if 'data' in response: + data = response['data'] + if 'webPages' in data: + webPages = data['webPages'] + if 'value' in webPages: results = [ { - "id": item.get("id", ""), - "name": item.get("name", ""), - "url": item.get("url", ""), - "snippet": item.get("snippet", ""), - "summary": item.get("summary", ""), - "siteName": item.get("siteName", ""), - "siteIcon": item.get("siteIcon", ""), - "datePublished": item.get("datePublished", "") - or item.get("dateLastCrawled", ""), + 'id': item.get('id', ''), + 'name': item.get('name', ''), + 'url': item.get('url', ''), + 'snippet': item.get('snippet', ''), + 'summary': item.get('summary', ''), + 'siteName': item.get('siteName', ''), + 'siteIcon': item.get('siteIcon', ''), + 'datePublished': item.get('datePublished', '') or item.get('dateLastCrawled', ''), } - for item in webPages["value"] + for item in webPages['value'] ] return results -def search_bocha( - api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None -) -> list[SearchResult]: +def search_bocha(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]: """Search using Bocha's Search API and return the results as a list of SearchResult objects. Args: api_key (str): A Bocha Search API key query (str): The query to search for """ - url = "https://api.bochaai.com/v1/web-search?utm_source=ollama" - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + url = 'https://api.bochaai.com/v1/web-search?utm_source=ollama' + headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'} - payload = json.dumps( - {"query": query, "summary": True, "freshness": "noLimit", "count": count} - ) + payload = json.dumps({'query': query, 'summary': True, 'freshness': 'noLimit', 'count': count}) response = requests.post(url, headers=headers, data=payload, timeout=5) response.raise_for_status() @@ -56,8 +51,6 @@ def search_bocha( results = get_filtered_results(results, filter_list) return [ - SearchResult( - link=result["url"], title=result.get("name"), snippet=result.get("summary") - ) + SearchResult(link=result['url'], title=result.get('name'), snippet=result.get('summary')) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index 49c8a88e81..9e663c2684 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -8,44 +8,42 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) -def search_brave( - api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None -) -> list[SearchResult]: +def search_brave(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]: """Search using Brave's Search API and return the results as a list of SearchResult objects. Args: api_key (str): A Brave Search API key query (str): The query to search for """ - url = "https://api.search.brave.com/res/v1/web/search" + url = 'https://api.search.brave.com/res/v1/web/search' headers = { - "Accept": "application/json", - "Accept-Encoding": "gzip", - "X-Subscription-Token": api_key, + 'Accept': 'application/json', + 'Accept-Encoding': 'gzip', + 'X-Subscription-Token': api_key, } - params = {"q": query, "count": count} + params = {'q': query, 'count': count} response = requests.get(url, headers=headers, params=params) # Handle 429 rate limiting - Brave free tier allows 1 request/second # If rate limited, wait 1 second and retry once before failing if response.status_code == 429: - log.info("Brave Search API rate limited (429), retrying after 1 second...") + log.info('Brave Search API rate limited (429), retrying after 1 second...') time.sleep(1) response = requests.get(url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = json_response.get("web", {}).get("results", []) + results = json_response.get('web', {}).get('results', []) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["url"], - title=result.get("title"), - snippet=result.get("description"), + link=result['url'], + title=result.get('title'), + snippet=result.get('description'), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index 7528418cdb..da1c3f77ec 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -13,7 +13,7 @@ def search_duckduckgo( count: int, filter_list: Optional[list[str]] = None, concurrent_requests: Optional[int] = None, - backend: Optional[str] = "auto", + backend: Optional[str] = 'auto', ) -> list[SearchResult]: """ Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. @@ -33,20 +33,18 @@ def search_duckduckgo( # Use the ddgs.text() method to perform the search try: - search_results = ddgs.text( - query, safesearch="moderate", max_results=count, backend=backend - ) + search_results = ddgs.text(query, safesearch='moderate', max_results=count, backend=backend) except RatelimitException as e: - log.error(f"RatelimitException: {e}") + log.error(f'RatelimitException: {e}') if filter_list: search_results = get_filtered_results(search_results, filter_list) # Return the list of search results return [ SearchResult( - link=result["href"], - title=result.get("title"), - snippet=result.get("body"), + link=result['href'], + title=result.get('title'), + snippet=result.get('body'), ) for result in search_results ] diff --git a/backend/open_webui/retrieval/web/exa.py b/backend/open_webui/retrieval/web/exa.py index df9554fab2..860917854e 100644 --- a/backend/open_webui/retrieval/web/exa.py +++ b/backend/open_webui/retrieval/web/exa.py @@ -7,7 +7,7 @@ from open_webui.retrieval.web.main import SearchResult log = logging.getLogger(__name__) -EXA_API_BASE = "https://api.exa.ai" +EXA_API_BASE = 'https://api.exa.ai' @dataclass @@ -31,36 +31,34 @@ def search_exa( count (int): Number of results to return filter_list (Optional[list[str]]): List of domains to filter results by """ - log.info(f"Searching with Exa for query: {query}") + log.info(f'Searching with Exa for query: {query}') - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'} payload = { - "query": query, - "numResults": count or 5, - "includeDomains": filter_list, - "contents": {"text": True, "highlights": True}, - "type": "auto", # Use the auto search type (keyword or neural) + 'query': query, + 'numResults': count or 5, + 'includeDomains': filter_list, + 'contents': {'text': True, 'highlights': True}, + 'type': 'auto', # Use the auto search type (keyword or neural) } try: - response = requests.post( - f"{EXA_API_BASE}/search", headers=headers, json=payload - ) + response = requests.post(f'{EXA_API_BASE}/search', headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] - for result in data["results"]: + for result in data['results']: results.append( ExaResult( - url=result["url"], - title=result["title"], - text=result["text"], + url=result['url'], + title=result['title'], + text=result['text'], ) ) - log.info(f"Found {len(results)} results") + log.info(f'Found {len(results)} results') return [ SearchResult( link=result.url, @@ -70,5 +68,5 @@ def search_exa( for result in results ] except Exception as e: - log.error(f"Error searching Exa: {e}") + log.error(f'Error searching Exa: {e}') return [] diff --git a/backend/open_webui/retrieval/web/external.py b/backend/open_webui/retrieval/web/external.py index e8cf72e9f0..7f5a2bf2af 100644 --- a/backend/open_webui/retrieval/web/external.py +++ b/backend/open_webui/retrieval/web/external.py @@ -24,12 +24,12 @@ def search_external( ) -> List[SearchResult]: try: headers = { - "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", - "Authorization": f"Bearer {external_api_key}", + 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot', + 'Authorization': f'Bearer {external_api_key}', } headers = include_user_info_headers(headers, user) - chat_id = getattr(request.state, "chat_id", None) + chat_id = getattr(request.state, 'chat_id', None) if chat_id: headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id) @@ -37,8 +37,8 @@ def search_external( external_url, headers=headers, json={ - "query": query, - "count": count, + 'query': query, + 'count': count, }, ) response.raise_for_status() @@ -47,14 +47,14 @@ def search_external( results = get_filtered_results(results, filter_list) results = [ SearchResult( - link=result.get("link"), - title=result.get("title"), - snippet=result.get("snippet"), + link=result.get('link'), + title=result.get('title'), + snippet=result.get('snippet'), ) for result in results[:count] ] - log.info(f"External search results: {results}") + log.info(f'External search results: {results}') return results except Exception as e: - log.error(f"Error in External search: {e}") + log.error(f'Error in External search: {e}') return [] diff --git a/backend/open_webui/retrieval/web/firecrawl.py b/backend/open_webui/retrieval/web/firecrawl.py index e6e96992a1..4bb23e3797 100644 --- a/backend/open_webui/retrieval/web/firecrawl.py +++ b/backend/open_webui/retrieval/web/firecrawl.py @@ -17,9 +17,7 @@ def search_firecrawl( from firecrawl import FirecrawlApp firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url) - response = firecrawl.search( - query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3 - ) + response = firecrawl.search(query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3) results = response.web if filter_list: results = get_filtered_results(results, filter_list) @@ -31,8 +29,8 @@ def search_firecrawl( ) for result in results[:count] ] - log.info(f"External search results: {results}") + log.info(f'External search results: {results}') return results except Exception as e: - log.error(f"Error in External search: {e}") + log.error(f'Error in External search: {e}') return [] diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index 96fa8c98cd..bb0a852658 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -28,11 +28,11 @@ def search_google_pse( Returns: list[SearchResult]: A list of SearchResult objects. """ - url = "https://www.googleapis.com/customsearch/v1" + url = 'https://www.googleapis.com/customsearch/v1' - headers = {"Content-Type": "application/json"} + headers = {'Content-Type': 'application/json'} if referer: - headers["Referer"] = referer + headers['Referer'] = referer all_results = [] start_index = 1 # Google PSE start parameter is 1-based @@ -40,21 +40,19 @@ def search_google_pse( while count > 0: num_results_this_page = min(count, 10) # Google PSE max results per page is 10 params = { - "cx": search_engine_id, - "q": query, - "key": api_key, - "num": num_results_this_page, - "start": start_index, + 'cx': search_engine_id, + 'q': query, + 'key': api_key, + 'num': num_results_this_page, + 'start': start_index, } - response = requests.request("GET", url, headers=headers, params=params) + response = requests.request('GET', url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = json_response.get("items", []) + results = json_response.get('items', []) if results: # check if results are returned. If not, no more pages to fetch. all_results.extend(results) - count -= len( - results - ) # Decrement count by the number of results fetched in this page. + count -= len(results) # Decrement count by the number of results fetched in this page. start_index += 10 # Increment start index for the next page else: break # No more results from Google PSE, break the loop @@ -64,9 +62,9 @@ def search_google_pse( return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("snippet"), + link=result['link'], + title=result.get('title'), + snippet=result.get('snippet'), ) for result in all_results ] diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index d1168bb36f..b3266c47d0 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -7,9 +7,7 @@ from yarl import URL log = logging.getLogger(__name__) -def search_jina( - api_key: str, query: str, count: int, base_url: str = "" -) -> list[SearchResult]: +def search_jina(api_key: str, query: str, count: int, base_url: str = '') -> list[SearchResult]: """ Search using Jina's Search API and return the results as a list of SearchResult objects. Args: @@ -21,16 +19,16 @@ def search_jina( Returns: list[SearchResult]: A list of search results """ - jina_search_endpoint = base_url if base_url else "https://s.jina.ai/" + jina_search_endpoint = base_url if base_url else 'https://s.jina.ai/' headers = { - "Accept": "application/json", - "Content-Type": "application/json", - "Authorization": api_key, - "X-Retain-Images": "none", + 'Accept': 'application/json', + 'Content-Type': 'application/json', + 'Authorization': api_key, + 'X-Retain-Images': 'none', } - payload = {"q": query, "count": count if count <= 10 else 10} + payload = {'q': query, 'count': count if count <= 10 else 10} url = str(URL(jina_search_endpoint)) response = requests.post(url, headers=headers, json=payload) @@ -38,12 +36,12 @@ def search_jina( data = response.json() results = [] - for result in data["data"]: + for result in data['data']: results.append( SearchResult( - link=result["url"], - title=result.get("title"), - snippet=result.get("content"), + link=result['url'], + title=result.get('title'), + snippet=result.get('content'), ) ) diff --git a/backend/open_webui/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py index f0303acf69..e6ed570011 100644 --- a/backend/open_webui/retrieval/web/kagi.py +++ b/backend/open_webui/retrieval/web/kagi.py @@ -7,9 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) -def search_kagi( - api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None -) -> list[SearchResult]: +def search_kagi(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]: """Search using Kagi's Search API and return the results as a list of SearchResult objects. The Search API will inherit the settings in your account, including results personalization and snippet length. @@ -19,23 +17,21 @@ def search_kagi( query (str): The query to search for count (int): The number of results to return """ - url = "https://kagi.com/api/v0/search" + url = 'https://kagi.com/api/v0/search' headers = { - "Authorization": f"Bot {api_key}", + 'Authorization': f'Bot {api_key}', } - params = {"q": query, "limit": count} + params = {'q': query, 'limit': count} response = requests.get(url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - search_results = json_response.get("data", []) + search_results = json_response.get('data', []) results = [ - SearchResult( - link=result["url"], title=result["title"], snippet=result.get("snippet") - ) + SearchResult(link=result['url'], title=result['title'], snippet=result.get('snippet')) for result in search_results - if result["t"] == 0 + if result['t'] == 0 ] print(results) diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py index 1b8df9f8ee..3a8fed52dd 100644 --- a/backend/open_webui/retrieval/web/main.py +++ b/backend/open_webui/retrieval/web/main.py @@ -16,7 +16,7 @@ def get_filtered_results(results, filter_list): filtered_results = [] for result in results: - url = result.get("url") or result.get("link", "") or result.get("href", "") + url = result.get('url') or result.get('link', '') or result.get('href', '') if not validators.url(url): continue diff --git a/backend/open_webui/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py index d48f7aeef8..a094ef6fc8 100644 --- a/backend/open_webui/retrieval/web/mojeek.py +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -7,32 +7,27 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) -def search_mojeek( - api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None -) -> list[SearchResult]: +def search_mojeek(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]: """Search using Mojeek's Search API and return the results as a list of SearchResult objects. Args: api_key (str): A Mojeek Search API key query (str): The query to search for """ - url = "https://api.mojeek.com/search" + url = 'https://api.mojeek.com/search' headers = { - "Accept": "application/json", + 'Accept': 'application/json', } - params = {"q": query, "api_key": api_key, "fmt": "json", "t": count} + params = {'q': query, 'api_key': api_key, 'fmt': 'json', 't': count} response = requests.get(url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = json_response.get("response", {}).get("results", []) + results = json_response.get('response', {}).get('results', []) print(results) if filter_list: results = get_filtered_results(results, filter_list) return [ - SearchResult( - link=result["url"], title=result.get("title"), snippet=result.get("desc") - ) - for result in results + SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('desc')) for result in results ] diff --git a/backend/open_webui/retrieval/web/ollama.py b/backend/open_webui/retrieval/web/ollama.py index 71bd9d5124..7ed19b91b1 100644 --- a/backend/open_webui/retrieval/web/ollama.py +++ b/backend/open_webui/retrieval/web/ollama.py @@ -23,30 +23,30 @@ def search_ollama_cloud( count (int): Number of results to return filter_list (Optional[list[str]]): List of domains to filter results by """ - log.info(f"Searching with Ollama for query: {query}") + log.info(f'Searching with Ollama for query: {query}') - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} - payload = {"query": query, "max_results": count} + headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'} + payload = {'query': query, 'max_results': count} try: - response = requests.post(f"{url}/api/web_search", headers=headers, json=payload) + response = requests.post(f'{url}/api/web_search', headers=headers, json=payload) response.raise_for_status() data = response.json() - results = data.get("results", []) - log.info(f"Found {len(results)} results") + results = data.get('results', []) + log.info(f'Found {len(results)} results') if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result.get("url", ""), - title=result.get("title", ""), - snippet=result.get("content", ""), + link=result.get('url', ''), + title=result.get('title', ''), + snippet=result.get('content', ''), ) for result in results ] except Exception as e: - log.error(f"Error searching Ollama: {e}") + log.error(f'Error searching Ollama: {e}') return [] diff --git a/backend/open_webui/retrieval/web/perplexity.py b/backend/open_webui/retrieval/web/perplexity.py index aae802b432..8b3a9d3b08 100644 --- a/backend/open_webui/retrieval/web/perplexity.py +++ b/backend/open_webui/retrieval/web/perplexity.py @@ -5,13 +5,13 @@ import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results MODELS = Literal[ - "sonar", - "sonar-pro", - "sonar-reasoning", - "sonar-reasoning-pro", - "sonar-deep-research", + 'sonar', + 'sonar-pro', + 'sonar-reasoning', + 'sonar-reasoning-pro', + 'sonar-deep-research', ] -SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"] +SEARCH_CONTEXT_USAGE_LEVELS = Literal['low', 'medium', 'high'] log = logging.getLogger(__name__) @@ -22,8 +22,8 @@ def search_perplexity( query: str, count: int, filter_list: Optional[list[str]] = None, - model: MODELS = "sonar", - search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium", + model: MODELS = 'sonar', + search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = 'medium', ) -> list[SearchResult]: """Search using Perplexity API and return the results as a list of SearchResult objects. @@ -38,66 +38,63 @@ def search_perplexity( """ # Handle PersistentConfig object - if hasattr(api_key, "__str__"): + if hasattr(api_key, '__str__'): api_key = str(api_key) try: - url = "https://api.perplexity.ai/chat/completions" + url = 'https://api.perplexity.ai/chat/completions' # Create payload for the API call payload = { - "model": model, - "messages": [ + 'model': model, + 'messages': [ { - "role": "system", - "content": "You are a search assistant. Provide factual information with citations.", + 'role': 'system', + 'content': 'You are a search assistant. Provide factual information with citations.', }, - {"role": "user", "content": query}, + {'role': 'user', 'content': query}, ], - "temperature": 0.2, # Lower temperature for more factual responses - "stream": False, - "web_search_options": { - "search_context_usage": search_context_usage, + 'temperature': 0.2, # Lower temperature for more factual responses + 'stream': False, + 'web_search_options': { + 'search_context_usage': search_context_usage, }, } headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json', } # Make the API request - response = requests.request("POST", url, json=payload, headers=headers) + response = requests.request('POST', url, json=payload, headers=headers) # Parse the JSON response json_response = response.json() # Extract citations from the response - citations = json_response.get("citations", []) + citations = json_response.get('citations', []) # Create search results from citations results = [] for i, citation in enumerate(citations[:count]): # Extract content from the response to use as snippet - content = "" - if "choices" in json_response and json_response["choices"]: + content = '' + if 'choices' in json_response and json_response['choices']: if i == 0: - content = json_response["choices"][0]["message"]["content"] + content = json_response['choices'][0]['message']['content'] - result = {"link": citation, "title": f"Source {i+1}", "snippet": content} + result = {'link': citation, 'title': f'Source {i + 1}', 'snippet': content} results.append(result) if filter_list: - results = get_filtered_results(results, filter_list) return [ - SearchResult( - link=result["link"], title=result["title"], snippet=result["snippet"] - ) + SearchResult(link=result['link'], title=result['title'], snippet=result['snippet']) for result in results[:count] ] except Exception as e: - log.error(f"Error searching with Perplexity API: {e}") + log.error(f'Error searching with Perplexity API: {e}') return [] diff --git a/backend/open_webui/retrieval/web/perplexity_search.py b/backend/open_webui/retrieval/web/perplexity_search.py index 744a505c05..9cbec049d9 100644 --- a/backend/open_webui/retrieval/web/perplexity_search.py +++ b/backend/open_webui/retrieval/web/perplexity_search.py @@ -13,7 +13,7 @@ def search_perplexity_search( query: str, count: int, filter_list: Optional[list[str]] = None, - api_url: str = "https://api.perplexity.ai/search", + api_url: str = 'https://api.perplexity.ai/search', user=None, ) -> list[SearchResult]: """Search using Perplexity API and return the results as a list of SearchResult objects. @@ -29,10 +29,10 @@ def search_perplexity_search( """ # Handle PersistentConfig object - if hasattr(api_key, "__str__"): + if hasattr(api_key, '__str__'): api_key = str(api_key) - if hasattr(api_url, "__str__"): + if hasattr(api_url, '__str__'): api_url = str(api_url) try: @@ -40,13 +40,13 @@ def search_perplexity_search( # Create payload for the API call payload = { - "query": query, - "max_results": count, + 'query': query, + 'max_results': count, } headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json', } # Forward user info headers if user is provided @@ -54,20 +54,17 @@ def search_perplexity_search( headers = include_user_info_headers(headers, user) # Make the API request - response = requests.request("POST", url, json=payload, headers=headers) + response = requests.request('POST', url, json=payload, headers=headers) # Parse the JSON response json_response = response.json() # Extract citations from the response - results = json_response.get("results", []) + results = json_response.get('results', []) return [ - SearchResult( - link=result["url"], title=result["title"], snippet=result["snippet"] - ) - for result in results + SearchResult(link=result['url'], title=result['title'], snippet=result['snippet']) for result in results ] except Exception as e: - log.error(f"Error searching with Perplexity Search API: {e}") + log.error(f'Error searching with Perplexity Search API: {e}') return [] diff --git a/backend/open_webui/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py index caf781c5df..855269ef02 100644 --- a/backend/open_webui/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -21,28 +21,26 @@ def search_searchapi( api_key (str): A searchapi.io API key query (str): The query to search for """ - url = "https://www.searchapi.io/api/v1/search" + url = 'https://www.searchapi.io/api/v1/search' - engine = engine or "google" + engine = engine or 'google' - payload = {"engine": engine, "q": query, "api_key": api_key} + payload = {'engine': engine, 'q': query, 'api_key': api_key} - url = f"{url}?{urlencode(payload)}" - response = requests.request("GET", url) + url = f'{url}?{urlencode(payload)}' + response = requests.request('GET', url) json_response = response.json() - log.info(f"results from searchapi search: {json_response}") + log.info(f'results from searchapi search: {json_response}') - results = sorted( - json_response.get("organic_results", []), key=lambda x: x.get("position", 0) - ) + results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0)) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("snippet"), + link=result['link'], + title=result.get('title'), + snippet=result.get('snippet'), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py index b3d4eb8795..0335bea9a3 100644 --- a/backend/open_webui/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -38,38 +38,38 @@ def search_searxng( """ # Default values for optional parameters are provided as empty strings or None when not specified. - language = kwargs.get("language", "all") - safesearch = kwargs.get("safesearch", "1") - time_range = kwargs.get("time_range", "") - categories = "".join(kwargs.get("categories", [])) + language = kwargs.get('language', 'all') + safesearch = kwargs.get('safesearch', '1') + time_range = kwargs.get('time_range', '') + categories = ''.join(kwargs.get('categories', [])) params = { - "q": query, - "format": "json", - "pageno": 1, - "safesearch": safesearch, - "language": language, - "time_range": time_range, - "categories": categories, - "theme": "simple", - "image_proxy": 0, + 'q': query, + 'format': 'json', + 'pageno': 1, + 'safesearch': safesearch, + 'language': language, + 'time_range': time_range, + 'categories': categories, + 'theme': 'simple', + 'image_proxy': 0, } # Legacy query format - if "" in query_url: + if '' in query_url: # Strip all query parameters from the URL - query_url = query_url.split("?")[0] + query_url = query_url.split('?')[0] - log.debug(f"searching {query_url}") + log.debug(f'searching {query_url}') response = requests.get( query_url, headers={ - "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", - "Accept": "text/html", - "Accept-Encoding": "gzip, deflate", - "Accept-Language": "en-US,en;q=0.5", - "Connection": "keep-alive", + 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot', + 'Accept': 'text/html', + 'Accept-Encoding': 'gzip, deflate', + 'Accept-Language': 'en-US,en;q=0.5', + 'Connection': 'keep-alive', }, params=params, ) @@ -77,13 +77,11 @@ def search_searxng( response.raise_for_status() # Raise an exception for HTTP errors. json_response = response.json() - results = json_response.get("results", []) - sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True) + results = json_response.get('results', []) + sorted_results = sorted(results, key=lambda x: x.get('score', 0), reverse=True) if filter_list: sorted_results = get_filtered_results(sorted_results, filter_list) return [ - SearchResult( - link=result["url"], title=result.get("title"), snippet=result.get("content") - ) + SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('content')) for result in sorted_results[:count] ] diff --git a/backend/open_webui/retrieval/web/serpapi.py b/backend/open_webui/retrieval/web/serpapi.py index bb421b500f..602f60d7a7 100644 --- a/backend/open_webui/retrieval/web/serpapi.py +++ b/backend/open_webui/retrieval/web/serpapi.py @@ -21,28 +21,26 @@ def search_serpapi( api_key (str): A serpapi.com API key query (str): The query to search for """ - url = "https://serpapi.com/search" + url = 'https://serpapi.com/search' - engine = engine or "google" + engine = engine or 'google' - payload = {"engine": engine, "q": query, "api_key": api_key} + payload = {'engine': engine, 'q': query, 'api_key': api_key} - url = f"{url}?{urlencode(payload)}" - response = requests.request("GET", url) + url = f'{url}?{urlencode(payload)}' + response = requests.request('GET', url) json_response = response.json() - log.info(f"results from serpapi search: {json_response}") + log.info(f'results from serpapi search: {json_response}') - results = sorted( - json_response.get("organic_results", []), key=lambda x: x.get("position", 0) - ) + results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0)) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("snippet"), + link=result['link'], + title=result.get('title'), + snippet=result.get('snippet'), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py index 5a745e304e..98ab40cb19 100644 --- a/backend/open_webui/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -8,34 +8,30 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) -def search_serper( - api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None -) -> list[SearchResult]: +def search_serper(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. Args: api_key (str): A serper.dev API key query (str): The query to search for """ - url = "https://google.serper.dev/search" + url = 'https://google.serper.dev/search' - payload = json.dumps({"q": query}) - headers = {"X-API-KEY": api_key, "Content-Type": "application/json"} + payload = json.dumps({'q': query}) + headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'} - response = requests.request("POST", url, headers=headers, data=payload) + response = requests.request('POST', url, headers=headers, data=payload) response.raise_for_status() json_response = response.json() - results = sorted( - json_response.get("organic", []), key=lambda x: x.get("position", 0) - ) + results = sorted(json_response.get('organic', []), key=lambda x: x.get('position', 0)) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("description"), + link=result['link'], + title=result.get('title'), + snippet=result.get('description'), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py index 68843eba85..f245392b75 100644 --- a/backend/open_webui/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -12,10 +12,10 @@ def search_serply( api_key: str, query: str, count: int, - hl: str = "us", + hl: str = 'us', limit: int = 10, - device_type: str = "desktop", - proxy_location: str = "US", + device_type: str = 'desktop', + proxy_location: str = 'US', filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. @@ -26,42 +26,40 @@ def search_serply( hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages) limit (int): The maximum number of results to return [10-100, defaults to 10] """ - log.info("Searching with Serply") + log.info('Searching with Serply') - url = "https://api.serply.io/v1/search/" + url = 'https://api.serply.io/v1/search/' query_payload = { - "q": query, - "language": "en", - "num": limit, - "gl": proxy_location.upper(), - "hl": hl.lower(), + 'q': query, + 'language': 'en', + 'num': limit, + 'gl': proxy_location.upper(), + 'hl': hl.lower(), } - url = f"{url}{urlencode(query_payload)}" + url = f'{url}{urlencode(query_payload)}' headers = { - "X-API-KEY": api_key, - "X-User-Agent": device_type, - "User-Agent": "open-webui", - "X-Proxy-Location": proxy_location, + 'X-API-KEY': api_key, + 'X-User-Agent': device_type, + 'User-Agent': 'open-webui', + 'X-Proxy-Location': proxy_location, } - response = requests.request("GET", url, headers=headers) + response = requests.request('GET', url, headers=headers) response.raise_for_status() json_response = response.json() - log.info(f"results from serply search: {json_response}") + log.info(f'results from serply search: {json_response}') - results = sorted( - json_response.get("results", []), key=lambda x: x.get("realPosition", 0) - ) + results = sorted(json_response.get('results', []), key=lambda x: x.get('realPosition', 0)) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("description"), + link=result['link'], + title=result.get('title'), + snippet=result.get('description'), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py index 97db858724..28a4956645 100644 --- a/backend/open_webui/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -21,26 +21,22 @@ def search_serpstack( query (str): The query to search for https_enabled (bool): Whether to use HTTPS or HTTP for the API request """ - url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search" + url = f'{"https" if https_enabled else "http"}://api.serpstack.com/search' - headers = {"Content-Type": "application/json"} + headers = {'Content-Type': 'application/json'} params = { - "access_key": api_key, - "query": query, + 'access_key': api_key, + 'query': query, } - response = requests.request("POST", url, headers=headers, params=params) + response = requests.request('POST', url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = sorted( - json_response.get("organic_results", []), key=lambda x: x.get("position", 0) - ) + results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0)) if filter_list: results = get_filtered_results(results, filter_list) return [ - SearchResult( - link=result["url"], title=result.get("title"), snippet=result.get("snippet") - ) + SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('snippet')) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/sougou.py b/backend/open_webui/retrieval/web/sougou.py index d8747c3ade..b267374d79 100644 --- a/backend/open_webui/retrieval/web/sougou.py +++ b/backend/open_webui/retrieval/web/sougou.py @@ -26,33 +26,26 @@ def search_sougou( try: cred = credential.Credential(sougou_api_sid, sougou_api_sk) http_profile = HttpProfile() - http_profile.endpoint = "tms.tencentcloudapi.com" + http_profile.endpoint = 'tms.tencentcloudapi.com' client_profile = ClientProfile() client_profile.http_profile = http_profile - params = json.dumps({"Query": query, "Cnt": 20}) - common_client = CommonClient( - "tms", "2020-12-29", cred, "", profile=client_profile - ) + params = json.dumps({'Query': query, 'Cnt': 20}) + common_client = CommonClient('tms', '2020-12-29', cred, '', profile=client_profile) results = [ - json.loads(page) - for page in common_client.call_json("SearchPro", json.loads(params))[ - "Response" - ]["Pages"] + json.loads(page) for page in common_client.call_json('SearchPro', json.loads(params))['Response']['Pages'] ] - sorted_results = sorted( - results, key=lambda x: x.get("scour", 0.0), reverse=True - ) + sorted_results = sorted(results, key=lambda x: x.get('scour', 0.0), reverse=True) if filter_list: sorted_results = get_filtered_results(sorted_results, filter_list) return [ SearchResult( - link=result.get("url"), - title=result.get("title"), - snippet=result.get("passage"), + link=result.get('url'), + title=result.get('title'), + snippet=result.get('passage'), ) for result in sorted_results[:count] ] except TencentCloudSDKException as err: - log.error(f"Error in Sougou search: {err}") + log.error(f'Error in Sougou search: {err}') return [] diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index 6d9ff89a87..6b52bbb45b 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -24,26 +24,26 @@ def search_tavily( Returns: list[SearchResult]: A list of search results """ - url = "https://api.tavily.com/search" + url = 'https://api.tavily.com/search' headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}', } - data = {"query": query, "max_results": count} + data = {'query': query, 'max_results': count} response = requests.post(url, headers=headers, json=data) response.raise_for_status() json_response = response.json() - results = json_response.get("results", []) + results = json_response.get('results', []) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["url"], - title=result.get("title", ""), - snippet=result.get("content"), + link=result['url'], + title=result.get('title', ''), + snippet=result.get('content'), ) for result in results ] diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 45787cb4bd..c9442f208b 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -67,16 +67,14 @@ def validate_url(url: Union[str, Sequence[str]]): parsed_url = urllib.parse.urlparse(url) # Protocol validation - only allow http/https - if parsed_url.scheme not in ["http", "https"]: - log.warning( - f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}" - ) + if parsed_url.scheme not in ['http', 'https']: + log.warning(f'Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}') raise ValueError(ERROR_MESSAGES.INVALID_URL) # Blocklist check using unified filtering logic if WEB_FETCH_FILTER_LIST: if not is_string_allowed(url, WEB_FETCH_FILTER_LIST): - log.warning(f"URL blocked by filter list: {url}") + log.warning(f'URL blocked by filter list: {url}') raise ValueError(ERROR_MESSAGES.INVALID_URL) if not ENABLE_RAG_LOCAL_WEB_FETCH: @@ -106,29 +104,29 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]: if validate_url(u): valid_urls.append(u) except Exception as e: - log.debug(f"Invalid URL {u}: {str(e)}") + log.debug(f'Invalid URL {u}: {str(e)}') continue return valid_urls def extract_metadata(soup, url): - metadata = {"source": url} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get("content", "No description found.") - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") + metadata = {'source': url} + if title := soup.find('title'): + metadata['title'] = title.get_text() + if description := soup.find('meta', attrs={'name': 'description'}): + metadata['description'] = description.get('content', 'No description found.') + if html := soup.find('html'): + metadata['language'] = html.get('lang', 'No language found.') return metadata def verify_ssl_cert(url: str) -> bool: """Verify SSL certificate for the given URL.""" - if not url.startswith("https://"): + if not url.startswith('https://'): return True try: - hostname = url.split("://")[-1].split("/")[0] + hostname = url.split('://')[-1].split('/')[0] context = ssl.create_default_context(cafile=certifi.where()) with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s: s.connect((hostname, 443)) @@ -136,7 +134,7 @@ def verify_ssl_cert(url: str) -> bool: except ssl.SSLError: return False except Exception as e: - log.warning(f"SSL verification failed for {url}: {str(e)}") + log.warning(f'SSL verification failed for {url}: {str(e)}') return False @@ -168,14 +166,14 @@ class URLProcessingMixin: async def _safe_process_url(self, url: str) -> bool: """Perform safety checks before processing a URL.""" if self.verify_ssl and not await self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") + raise ValueError(f'SSL certificate verification failed for {url}') await self._wait_for_rate_limit() return True def _safe_process_url_sync(self, url: str) -> bool: """Synchronous version of safety checks.""" if self.verify_ssl and not verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") + raise ValueError(f'SSL certificate verification failed for {url}') self._sync_wait_for_rate_limit() return True @@ -191,7 +189,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): api_key: Optional[str] = None, api_url: Optional[str] = None, timeout: Optional[int] = None, - mode: Literal["crawl", "scrape", "map"] = "scrape", + mode: Literal['crawl', 'scrape', 'map'] = 'scrape', proxy: Optional[Dict[str, str]] = None, params: Optional[Dict] = None, ): @@ -216,15 +214,15 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): params: The parameters to pass to the Firecrawl API. For more details, visit: https://docs.firecrawl.dev/sdks/python#batch-scrape """ - proxy_server = proxy.get("server") if proxy else None + proxy_server = proxy.get('server') if proxy else None if trust_env and not proxy_server: env_proxies = urllib.request.getproxies() - env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + env_proxy_server = env_proxies.get('https') or env_proxies.get('http') if env_proxy_server: if proxy: - proxy["server"] = env_proxy_server + proxy['server'] = env_proxy_server else: - proxy = {"server": env_proxy_server} + proxy = {'server': env_proxy_server} self.web_paths = web_paths self.verify_ssl = verify_ssl self.requests_per_second = requests_per_second @@ -240,7 +238,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): def lazy_load(self) -> Iterator[Document]: """Load documents using FireCrawl batch_scrape.""" log.debug( - "Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s", + 'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s', len(self.web_paths), self.mode, self.params, @@ -251,7 +249,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url) result = firecrawl.batch_scrape( self.web_paths, - formats=["markdown"], + formats=['markdown'], skip_tls_verification=not self.verify_ssl, ignore_invalid_urls=True, remove_base64_images=True, @@ -260,28 +258,26 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): **self.params, ) - if result.status != "completed": - raise RuntimeError( - f"FireCrawl batch scrape did not complete successfully. result: {result}" - ) + if result.status != 'completed': + raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}') for data in result.data: metadata = data.metadata or {} yield Document( - page_content=data.markdown or "", - metadata={"source": metadata.url or metadata.source_url or ""}, + page_content=data.markdown or '', + metadata={'source': metadata.url or metadata.source_url or ''}, ) except Exception as e: if self.continue_on_failure: - log.exception(f"Error extracting content from URLs: {e}") + log.exception(f'Error extracting content from URLs: {e}') else: raise e async def alazy_load(self): """Async version of lazy_load.""" log.debug( - "Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s", + 'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s', len(self.web_paths), self.mode, self.params, @@ -292,7 +288,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url) result = firecrawl.batch_scrape( self.web_paths, - formats=["markdown"], + formats=['markdown'], skip_tls_verification=not self.verify_ssl, ignore_invalid_urls=True, remove_base64_images=True, @@ -301,21 +297,19 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): **self.params, ) - if result.status != "completed": - raise RuntimeError( - f"FireCrawl batch scrape did not complete successfully. result: {result}" - ) + if result.status != 'completed': + raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}') for data in result.data: metadata = data.metadata or {} yield Document( - page_content=data.markdown or "", - metadata={"source": metadata.url or metadata.source_url or ""}, + page_content=data.markdown or '', + metadata={'source': metadata.url or metadata.source_url or ''}, ) except Exception as e: if self.continue_on_failure: - log.exception(f"Error extracting content from URLs: {e}") + log.exception(f'Error extracting content from URLs: {e}') else: raise e @@ -325,7 +319,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): self, web_paths: Union[str, List[str]], api_key: str, - extract_depth: Literal["basic", "advanced"] = "basic", + extract_depth: Literal['basic', 'advanced'] = 'basic', continue_on_failure: bool = True, requests_per_second: Optional[float] = None, verify_ssl: bool = True, @@ -345,15 +339,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): proxy: Optional proxy configuration. """ # Initialize proxy configuration if using environment variables - proxy_server = proxy.get("server") if proxy else None + proxy_server = proxy.get('server') if proxy else None if trust_env and not proxy_server: env_proxies = urllib.request.getproxies() - env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + env_proxy_server = env_proxies.get('https') or env_proxies.get('http') if env_proxy_server: if proxy: - proxy["server"] = env_proxy_server + proxy['server'] = env_proxy_server else: - proxy = {"server": env_proxy_server} + proxy = {'server': env_proxy_server} # Store parameters for creating TavilyLoader instances self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths] @@ -376,14 +370,14 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): self._safe_process_url_sync(url) valid_urls.append(url) except Exception as e: - log.warning(f"SSL verification failed for {url}: {str(e)}") + log.warning(f'SSL verification failed for {url}: {str(e)}') if not self.continue_on_failure: raise e if not valid_urls: if self.continue_on_failure: - log.warning("No valid URLs to process after SSL verification") + log.warning('No valid URLs to process after SSL verification') return - raise ValueError("No valid URLs to process after SSL verification") + raise ValueError('No valid URLs to process after SSL verification') try: loader = TavilyLoader( urls=valid_urls, @@ -394,7 +388,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): yield from loader.lazy_load() except Exception as e: if self.continue_on_failure: - log.exception(f"Error extracting content from URLs: {e}") + log.exception(f'Error extracting content from URLs: {e}') else: raise e @@ -406,15 +400,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): await self._safe_process_url(url) valid_urls.append(url) except Exception as e: - log.warning(f"SSL verification failed for {url}: {str(e)}") + log.warning(f'SSL verification failed for {url}: {str(e)}') if not self.continue_on_failure: raise e if not valid_urls: if self.continue_on_failure: - log.warning("No valid URLs to process after SSL verification") + log.warning('No valid URLs to process after SSL verification') return - raise ValueError("No valid URLs to process after SSL verification") + raise ValueError('No valid URLs to process after SSL verification') try: loader = TavilyLoader( @@ -427,7 +421,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): yield document except Exception as e: if self.continue_on_failure: - log.exception(f"Error loading URLs: {e}") + log.exception(f'Error loading URLs: {e}') else: raise e @@ -462,15 +456,15 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing ): """Initialize with additional safety parameters and remote browser support.""" - proxy_server = proxy.get("server") if proxy else None + proxy_server = proxy.get('server') if proxy else None if trust_env and not proxy_server: env_proxies = urllib.request.getproxies() - env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + env_proxy_server = env_proxies.get('https') or env_proxies.get('http') if env_proxy_server: if proxy: - proxy["server"] = env_proxy_server + proxy['server'] = env_proxy_server else: - proxy = {"server": env_proxy_server} + proxy = {'server': env_proxy_server} # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser super().__init__( @@ -504,14 +498,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing page = browser.new_page() response = page.goto(url, timeout=self.playwright_timeout) if response is None: - raise ValueError(f"page.goto() returned None for url {url}") + raise ValueError(f'page.goto() returned None for url {url}') text = self.evaluator.evaluate(page, browser, response) - metadata = {"source": url} + metadata = {'source': url} yield Document(page_content=text, metadata=metadata) except Exception as e: if self.continue_on_failure: - log.exception(f"Error loading {url}: {e}") + log.exception(f'Error loading {url}: {e}') continue raise e browser.close() @@ -525,9 +519,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing if self.playwright_ws_url: browser = await p.chromium.connect(self.playwright_ws_url) else: - browser = await p.chromium.launch( - headless=self.headless, proxy=self.proxy - ) + browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy) for url in self.urls: try: @@ -535,14 +527,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing page = await browser.new_page() response = await page.goto(url, timeout=self.playwright_timeout) if response is None: - raise ValueError(f"page.goto() returned None for url {url}") + raise ValueError(f'page.goto() returned None for url {url}') text = await self.evaluator.evaluate_async(page, browser, response) - metadata = {"source": url} + metadata = {'source': url} yield Document(page_content=text, metadata=metadata) except Exception as e: if self.continue_on_failure: - log.exception(f"Error loading {url}: {e}") + log.exception(f'Error loading {url}: {e}') continue raise e await browser.close() @@ -560,9 +552,7 @@ class SafeWebBaseLoader(WebBaseLoader): super().__init__(*args, **kwargs) self.trust_env = trust_env - async def _fetch( - self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5 - ) -> str: + async def _fetch(self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5) -> str: async with aiohttp.ClientSession(trust_env=self.trust_env) as session: for i in range(retries): try: @@ -571,7 +561,7 @@ class SafeWebBaseLoader(WebBaseLoader): cookies=self.session.cookies.get_dict(), ) if not self.session.verify: - kwargs["ssl"] = False + kwargs['ssl'] = False async with session.get( url, @@ -585,16 +575,11 @@ class SafeWebBaseLoader(WebBaseLoader): if i == retries - 1: raise else: - log.warning( - f"Error fetching {url} with attempt " - f"{i + 1}/{retries}: {e}. Retrying..." - ) + log.warning(f'Error fetching {url} with attempt {i + 1}/{retries}: {e}. Retrying...') await asyncio.sleep(cooldown * backoff**i) - raise ValueError("retry count exceeded") + raise ValueError('retry count exceeded') - def _unpack_fetch_results( - self, results: Any, urls: List[str], parser: Union[str, None] = None - ) -> List[Any]: + def _unpack_fetch_results(self, results: Any, urls: List[str], parser: Union[str, None] = None) -> List[Any]: """Unpack fetch results into BeautifulSoup objects.""" from bs4 import BeautifulSoup @@ -602,17 +587,15 @@ class SafeWebBaseLoader(WebBaseLoader): for i, result in enumerate(results): url = urls[i] if parser is None: - if url.endswith(".xml"): - parser = "xml" + if url.endswith('.xml'): + parser = 'xml' else: parser = self.default_parser self._check_parser(parser) final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs)) return final_results - async def ascrape_all( - self, urls: List[str], parser: Union[str, None] = None - ) -> List[Any]: + async def ascrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]: """Async fetch all urls, then return soups for all results.""" results = await self.fetch_all(urls) return self._unpack_fetch_results(results, urls, parser=parser) @@ -630,22 +613,20 @@ class SafeWebBaseLoader(WebBaseLoader): yield Document(page_content=text, metadata=metadata) except Exception as e: # Log the error and continue with the next URL - log.exception(f"Error loading {path}: {e}") + log.exception(f'Error loading {path}: {e}') async def alazy_load(self) -> AsyncIterator[Document]: """Async lazy load text from the url(s) in web_path.""" results = await self.ascrape_all(self.web_paths) for path, soup in zip(self.web_paths, results): text = soup.get_text(**self.bs_get_text_kwargs) - metadata = {"source": path} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get( - "content", "No description found." - ) - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") + metadata = {'source': path} + if title := soup.find('title'): + metadata['title'] = title.get_text() + if description := soup.find('meta', attrs={'name': 'description'}): + metadata['description'] = description.get('content', 'No description found.') + if html := soup.find('html'): + metadata['language'] = html.get('lang', 'No language found.') yield Document(page_content=text, metadata=metadata) async def aload(self) -> list[Document]: @@ -663,18 +644,18 @@ def get_web_loader( safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls) if not safe_urls: - log.warning(f"All provided URLs were blocked or invalid: {urls}") + log.warning(f'All provided URLs were blocked or invalid: {urls}') raise ValueError(ERROR_MESSAGES.INVALID_URL) web_loader_args = { - "web_paths": safe_urls, - "verify_ssl": verify_ssl, - "requests_per_second": requests_per_second, - "continue_on_failure": True, - "trust_env": trust_env, + 'web_paths': safe_urls, + 'verify_ssl': verify_ssl, + 'requests_per_second': requests_per_second, + 'continue_on_failure': True, + 'trust_env': trust_env, } - if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web": + if WEB_LOADER_ENGINE.value == '' or WEB_LOADER_ENGINE.value == 'safe_web': WebLoaderClass = SafeWebBaseLoader request_kwargs = {} @@ -685,42 +666,42 @@ def get_web_loader( timeout_value = None if timeout_value: - request_kwargs["timeout"] = timeout_value + request_kwargs['timeout'] = timeout_value if request_kwargs: - web_loader_args["requests_kwargs"] = request_kwargs + web_loader_args['requests_kwargs'] = request_kwargs - if WEB_LOADER_ENGINE.value == "playwright": + if WEB_LOADER_ENGINE.value == 'playwright': WebLoaderClass = SafePlaywrightURLLoader - web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value + web_loader_args['playwright_timeout'] = PLAYWRIGHT_TIMEOUT.value if PLAYWRIGHT_WS_URL.value: - web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value + web_loader_args['playwright_ws_url'] = PLAYWRIGHT_WS_URL.value - if WEB_LOADER_ENGINE.value == "firecrawl": + if WEB_LOADER_ENGINE.value == 'firecrawl': WebLoaderClass = SafeFireCrawlLoader - web_loader_args["api_key"] = FIRECRAWL_API_KEY.value - web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value + web_loader_args['api_key'] = FIRECRAWL_API_KEY.value + web_loader_args['api_url'] = FIRECRAWL_API_BASE_URL.value if FIRECRAWL_TIMEOUT.value: try: - web_loader_args["timeout"] = int(FIRECRAWL_TIMEOUT.value) + web_loader_args['timeout'] = int(FIRECRAWL_TIMEOUT.value) except ValueError: pass - if WEB_LOADER_ENGINE.value == "tavily": + if WEB_LOADER_ENGINE.value == 'tavily': WebLoaderClass = SafeTavilyLoader - web_loader_args["api_key"] = TAVILY_API_KEY.value - web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value + web_loader_args['api_key'] = TAVILY_API_KEY.value + web_loader_args['extract_depth'] = TAVILY_EXTRACT_DEPTH.value - if WEB_LOADER_ENGINE.value == "external": + if WEB_LOADER_ENGINE.value == 'external': WebLoaderClass = ExternalWebLoader - web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value - web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value + web_loader_args['external_url'] = EXTERNAL_WEB_LOADER_URL.value + web_loader_args['external_api_key'] = EXTERNAL_WEB_LOADER_API_KEY.value if WebLoaderClass: web_loader = WebLoaderClass(**web_loader_args) log.debug( - "Using WEB_LOADER_ENGINE %s for %s URLs", + 'Using WEB_LOADER_ENGINE %s for %s URLs', web_loader.__class__.__name__, len(safe_urls), ) @@ -728,6 +709,6 @@ def get_web_loader( return web_loader else: raise ValueError( - f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. " + f'Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. ' "Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'." ) diff --git a/backend/open_webui/retrieval/web/yacy.py b/backend/open_webui/retrieval/web/yacy.py index 2419717b24..32ca04f531 100644 --- a/backend/open_webui/retrieval/web/yacy.py +++ b/backend/open_webui/retrieval/web/yacy.py @@ -41,29 +41,29 @@ def search_yacy( yacy_auth = HTTPDigestAuth(username, password) params = { - "query": query, - "contentdom": "text", - "resource": "global", - "maximumRecords": count, - "nav": "none", + 'query': query, + 'contentdom': 'text', + 'resource': 'global', + 'maximumRecords': count, + 'nav': 'none', } # Check if provided a json API URL - if not query_url.endswith("yacysearch.json"): + if not query_url.endswith('yacysearch.json'): # Strip all query parameters from the URL - query_url = query_url.rstrip("/") + "/yacysearch.json" + query_url = query_url.rstrip('/') + '/yacysearch.json' - log.debug(f"searching {query_url}") + log.debug(f'searching {query_url}') response = requests.get( query_url, auth=yacy_auth, headers={ - "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", - "Accept": "text/html", - "Accept-Encoding": "gzip, deflate", - "Accept-Language": "en-US,en;q=0.5", - "Connection": "keep-alive", + 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot', + 'Accept': 'text/html', + 'Accept-Encoding': 'gzip, deflate', + 'Accept-Language': 'en-US,en;q=0.5', + 'Connection': 'keep-alive', }, params=params, ) @@ -71,15 +71,15 @@ def search_yacy( response.raise_for_status() # Raise an exception for HTTP errors. json_response = response.json() - results = json_response.get("channels", [{}])[0].get("items", []) - sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True) + results = json_response.get('channels', [{}])[0].get('items', []) + sorted_results = sorted(results, key=lambda x: x.get('ranking', 0), reverse=True) if filter_list: sorted_results = get_filtered_results(sorted_results, filter_list) return [ SearchResult( - link=result["link"], - title=result.get("title"), - snippet=result.get("description"), + link=result['link'], + title=result.get('title'), + snippet=result.get('description'), ) for result in sorted_results[:count] ] diff --git a/backend/open_webui/retrieval/web/yandex.py b/backend/open_webui/retrieval/web/yandex.py index fba4ee482e..352d2a3afb 100644 --- a/backend/open_webui/retrieval/web/yandex.py +++ b/backend/open_webui/retrieval/web/yandex.py @@ -20,14 +20,14 @@ log = logging.getLogger(__name__) def xml_element_contents_to_string(element: Element) -> str: - buffer = [element.text if element.text else ""] + buffer = [element.text if element.text else ''] for child in element: buffer.append(xml_element_contents_to_string(child)) - buffer.append(element.tail if element.tail else "") + buffer.append(element.tail if element.tail else '') - return "".join(buffer) + return ''.join(buffer) def search_yandex( @@ -42,42 +42,38 @@ def search_yandex( ) -> List[SearchResult]: try: headers = { - "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", - "Authorization": f"Api-Key {yandex_search_api_key}", + 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot', + 'Authorization': f'Api-Key {yandex_search_api_key}', } if user is not None: headers = include_user_info_headers(headers, user) - chat_id = getattr(request.state, "chat_id", None) + chat_id = getattr(request.state, 'chat_id', None) if chat_id: headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id) - payload = {} if yandex_search_config == "" else json.loads(yandex_search_config) + payload = {} if yandex_search_config == '' else json.loads(yandex_search_config) - if type(payload.get("query", None)) != dict: - payload["query"] = {} + if type(payload.get('query', None)) != dict: + payload['query'] = {} - if "searchType" not in payload["query"]: - payload["query"]["searchType"] = "SEARCH_TYPE_RU" + if 'searchType' not in payload['query']: + payload['query']['searchType'] = 'SEARCH_TYPE_RU' - payload["query"]["queryText"] = query + payload['query']['queryText'] = query - if type(payload.get("groupSpec", None)) != dict: - payload["groupSpec"] = {} + if type(payload.get('groupSpec', None)) != dict: + payload['groupSpec'] = {} - if "groupMode" not in payload["groupSpec"]: - payload["groupSpec"]["groupMode"] = "GROUP_MODE_DEEP" + if 'groupMode' not in payload['groupSpec']: + payload['groupSpec']['groupMode'] = 'GROUP_MODE_DEEP' - payload["groupSpec"]["groupsOnPage"] = count - payload["groupSpec"]["docsInGroup"] = 1 + payload['groupSpec']['groupsOnPage'] = count + payload['groupSpec']['docsInGroup'] = 1 response = requests.post( - ( - "https://searchapi.api.cloud.yandex.net/v2/web/search" - if yandex_search_url == "" - else yandex_search_url - ), + ('https://searchapi.api.cloud.yandex.net/v2/web/search' if yandex_search_url == '' else yandex_search_url), headers=headers, json=payload, ) @@ -85,29 +81,21 @@ def search_yandex( response.raise_for_status() response_body = response.json() - if "rawData" not in response_body: - raise Exception(f"No `rawData` in response body: {response_body}") + if 'rawData' not in response_body: + raise Exception(f'No `rawData` in response body: {response_body}') - search_result_body_bytes = base64.decodebytes( - bytes(response_body["rawData"], "utf-8") - ) + search_result_body_bytes = base64.decodebytes(bytes(response_body['rawData'], 'utf-8')) doc_root = ET.parse(io.BytesIO(search_result_body_bytes)) results = [] - for group in doc_root.findall("response/results/grouping/group"): + for group in doc_root.findall('response/results/grouping/group'): results.append( { - "url": xml_element_contents_to_string(group.find("doc/url")).strip( - "\n" - ), - "title": xml_element_contents_to_string( - group.find("doc/title") - ).strip("\n"), - "snippet": xml_element_contents_to_string( - group.find("doc/passages/passage") - ), + 'url': xml_element_contents_to_string(group.find('doc/url')).strip('\n'), + 'title': xml_element_contents_to_string(group.find('doc/title')).strip('\n'), + 'snippet': xml_element_contents_to_string(group.find('doc/passages/passage')), } ) @@ -115,49 +103,47 @@ def search_yandex( results = [ SearchResult( - link=result.get("url"), - title=result.get("title"), - snippet=result.get("snippet"), + link=result.get('url'), + title=result.get('title'), + snippet=result.get('snippet'), ) for result in results[:count] ] - log.info(f"Yandex search results: {results}") + log.info(f'Yandex search results: {results}') return results except Exception as e: - log.error(f"Error in search: {e}") + log.error(f'Error in search: {e}') return [] -if __name__ == "__main__": +if __name__ == '__main__': from starlette.datastructures import Headers from fastapi import FastAPI result = search_yandex( Request( { - "type": "http", - "asgi.version": "3.0", - "asgi.spec_version": "2.0", - "method": "GET", - "path": "/internal", - "query_string": b"", - "headers": Headers({}).raw, - "client": ("127.0.0.1", 12345), - "server": ("127.0.0.1", 80), - "scheme": "http", - "app": FastAPI(), + 'type': 'http', + 'asgi.version': '3.0', + 'asgi.spec_version': '2.0', + 'method': 'GET', + 'path': '/internal', + 'query_string': b'', + 'headers': Headers({}).raw, + 'client': ('127.0.0.1', 12345), + 'server': ('127.0.0.1', 80), + 'scheme': 'http', + 'app': FastAPI(), }, None, ), - os.environ.get("YANDEX_WEB_SEARCH_URL", ""), - os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""), - os.environ.get( - "YANDEX_WEB_SEARCH_CONFIG", '{"query": {"searchType": "SEARCH_TYPE_COM"}}' - ), - "TOP movies of the past year", + os.environ.get('YANDEX_WEB_SEARCH_URL', ''), + os.environ.get('YANDEX_WEB_SEARCH_API_KEY', ''), + os.environ.get('YANDEX_WEB_SEARCH_CONFIG', '{"query": {"searchType": "SEARCH_TYPE_COM"}}'), + 'TOP movies of the past year', 3, ) diff --git a/backend/open_webui/retrieval/web/ydc.py b/backend/open_webui/retrieval/web/ydc.py index 21d725a895..21059d8b03 100644 --- a/backend/open_webui/retrieval/web/ydc.py +++ b/backend/open_webui/retrieval/web/ydc.py @@ -12,7 +12,7 @@ def search_youcom( query: str, count: int, filter_list: Optional[List[str]] = None, - language: str = "EN", + language: str = 'EN', ) -> List[SearchResult]: """Search using You.com's YDC Index API and return the results as a list of SearchResult objects. @@ -23,30 +23,30 @@ def search_youcom( filter_list (list[str], optional): Domain filter list language (str): Language code for search results (default: "EN") """ - url = "https://ydc-index.io/v1/search" + url = 'https://ydc-index.io/v1/search' headers = { - "Accept": "application/json", - "X-API-KEY": api_key, + 'Accept': 'application/json', + 'X-API-KEY': api_key, } params = { - "query": query, - "count": count, - "language": language, + 'query': query, + 'count': count, + 'language': language, } response = requests.get(url, headers=headers, params=params) response.raise_for_status() json_response = response.json() - results = json_response.get("results", {}).get("web", []) + results = json_response.get('results', {}).get('web', []) if filter_list: results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["url"], - title=result.get("title"), + link=result['url'], + title=result.get('title'), snippet=_build_snippet(result), ) for result in results[:count] @@ -62,12 +62,12 @@ def _build_snippet(result: dict) -> str: """ parts: list[str] = [] - description = result.get("description") + description = result.get('description') if description: parts.append(description) - snippets = result.get("snippets") + snippets = result.get('snippets') if snippets and isinstance(snippets, list): parts.extend(snippets) - return "\n\n".join(parts) + return '\n\n'.join(parts) diff --git a/backend/open_webui/routers/analytics.py b/backend/open_webui/routers/analytics.py index 9579845a49..790c134295 100644 --- a/backend/open_webui/routers/analytics.py +++ b/backend/open_webui/routers/analytics.py @@ -53,18 +53,16 @@ class UserAnalyticsResponse(BaseModel): #################### -@router.get("/models", response_model=ModelAnalyticsResponse) +@router.get('/models', response_model=ModelAnalyticsResponse) async def get_model_analytics( - start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), - end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), - group_id: Optional[str] = Query(None, description="Filter by user group ID"), + start_date: Optional[int] = Query(None, description='Start timestamp (epoch)'), + end_date: Optional[int] = Query(None, description='End timestamp (epoch)'), + group_id: Optional[str] = Query(None, description='Filter by user group ID'), user=Depends(get_admin_user), db: Session = Depends(get_session), ): """Get message counts per model.""" - counts = ChatMessages.get_message_count_by_model( - start_date=start_date, end_date=end_date, group_id=group_id, db=db - ) + counts = ChatMessages.get_message_count_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db) models = [ ModelAnalyticsEntry(model_id=model_id, count=count) for model_id, count in sorted(counts.items(), key=lambda x: -x[1]) @@ -72,27 +70,23 @@ async def get_model_analytics( return ModelAnalyticsResponse(models=models) -@router.get("/users", response_model=UserAnalyticsResponse) +@router.get('/users', response_model=UserAnalyticsResponse) async def get_user_analytics( - start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), - end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), - group_id: Optional[str] = Query(None, description="Filter by user group ID"), - limit: int = Query(50, description="Max users to return"), + start_date: Optional[int] = Query(None, description='Start timestamp (epoch)'), + end_date: Optional[int] = Query(None, description='End timestamp (epoch)'), + group_id: Optional[str] = Query(None, description='Filter by user group ID'), + limit: int = Query(50, description='Max users to return'), user=Depends(get_admin_user), db: Session = Depends(get_session), ): """Get message counts and token usage per user with user info.""" - counts = ChatMessages.get_message_count_by_user( - start_date=start_date, end_date=end_date, group_id=group_id, db=db - ) + counts = ChatMessages.get_message_count_by_user(start_date=start_date, end_date=end_date, group_id=group_id, db=db) token_usage = ChatMessages.get_token_usage_by_user( start_date=start_date, end_date=end_date, group_id=group_id, db=db ) # Get user info for top users - top_user_ids = [ - uid for uid, _ in sorted(counts.items(), key=lambda x: -x[1])[:limit] - ] + top_user_ids = [uid for uid, _ in sorted(counts.items(), key=lambda x: -x[1])[:limit]] user_info = {u.id: u for u in Users.get_users_by_user_ids(top_user_ids, db=db)} users = [] @@ -105,22 +99,22 @@ async def get_user_analytics( name=u.name if u else None, email=u.email if u else None, count=counts[user_id], - input_tokens=tokens.get("input_tokens", 0), - output_tokens=tokens.get("output_tokens", 0), - total_tokens=tokens.get("total_tokens", 0), + input_tokens=tokens.get('input_tokens', 0), + output_tokens=tokens.get('output_tokens', 0), + total_tokens=tokens.get('total_tokens', 0), ) ) return UserAnalyticsResponse(users=users) -@router.get("/messages", response_model=list[ChatMessageModel]) +@router.get('/messages', response_model=list[ChatMessageModel]) async def get_messages( - model_id: Optional[str] = Query(None, description="Filter by model ID"), - user_id: Optional[str] = Query(None, description="Filter by user ID"), - chat_id: Optional[str] = Query(None, description="Filter by chat ID"), - start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), - end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + model_id: Optional[str] = Query(None, description='Filter by model ID'), + user_id: Optional[str] = Query(None, description='Filter by user ID'), + chat_id: Optional[str] = Query(None, description='Filter by chat ID'), + start_date: Optional[int] = Query(None, description='Start timestamp (epoch)'), + end_date: Optional[int] = Query(None, description='End timestamp (epoch)'), skip: int = Query(0), limit: int = Query(50, le=100), user=Depends(get_admin_user), @@ -139,9 +133,7 @@ async def get_messages( db=db, ) elif user_id: - return ChatMessages.get_messages_by_user_id( - user_id=user_id, skip=skip, limit=limit, db=db - ) + return ChatMessages.get_messages_by_user_id(user_id=user_id, skip=skip, limit=limit, db=db) else: # Return empty if no filter specified return [] @@ -154,11 +146,11 @@ class SummaryResponse(BaseModel): total_users: int -@router.get("/summary", response_model=SummaryResponse) +@router.get('/summary', response_model=SummaryResponse) async def get_summary( - start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), - end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), - group_id: Optional[str] = Query(None, description="Filter by user group ID"), + start_date: Optional[int] = Query(None, description='Start timestamp (epoch)'), + end_date: Optional[int] = Query(None, description='End timestamp (epoch)'), + group_id: Optional[str] = Query(None, description='Filter by user group ID'), user=Depends(get_admin_user), db: Session = Depends(get_session), ): @@ -190,29 +182,24 @@ class DailyStatsResponse(BaseModel): data: list[DailyStatsEntry] -@router.get("/daily", response_model=DailyStatsResponse) +@router.get('/daily', response_model=DailyStatsResponse) async def get_daily_stats( - start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), - end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), - group_id: Optional[str] = Query(None, description="Filter by user group ID"), - granularity: str = Query("daily", description="Granularity: 'hourly' or 'daily'"), + start_date: Optional[int] = Query(None, description='Start timestamp (epoch)'), + end_date: Optional[int] = Query(None, description='End timestamp (epoch)'), + group_id: Optional[str] = Query(None, description='Filter by user group ID'), + granularity: str = Query('daily', description="Granularity: 'hourly' or 'daily'"), user=Depends(get_admin_user), db: Session = Depends(get_session), ): """Get message counts grouped by model for time-series chart.""" - if granularity == "hourly": - counts = ChatMessages.get_hourly_message_counts_by_model( - start_date=start_date, end_date=end_date, db=db - ) + if granularity == 'hourly': + counts = ChatMessages.get_hourly_message_counts_by_model(start_date=start_date, end_date=end_date, db=db) else: counts = ChatMessages.get_daily_message_counts_by_model( start_date=start_date, end_date=end_date, group_id=group_id, db=db ) return DailyStatsResponse( - data=[ - DailyStatsEntry(date=date, models=models) - for date, models in sorted(counts.items()) - ] + data=[DailyStatsEntry(date=date, models=models) for date, models in sorted(counts.items())] ) @@ -231,22 +218,20 @@ class TokenUsageResponse(BaseModel): total_tokens: int -@router.get("/tokens", response_model=TokenUsageResponse) +@router.get('/tokens', response_model=TokenUsageResponse) async def get_token_usage( start_date: Optional[int] = Query(None), end_date: Optional[int] = Query(None), - group_id: Optional[str] = Query(None, description="Filter by user group ID"), + group_id: Optional[str] = Query(None, description='Filter by user group ID'), user=Depends(get_admin_user), db: Session = Depends(get_session), ): """Get token usage aggregated by model.""" - usage = ChatMessages.get_token_usage_by_model( - start_date=start_date, end_date=end_date, group_id=group_id, db=db - ) + usage = ChatMessages.get_token_usage_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db) models = [ TokenUsageEntry(model_id=model_id, **data) - for model_id, data in sorted(usage.items(), key=lambda x: -x[1]["total_tokens"]) + for model_id, data in sorted(usage.items(), key=lambda x: -x[1]['total_tokens']) ] total_input = sum(m.input_tokens for m in models) @@ -278,7 +263,7 @@ class ModelChatsResponse(BaseModel): total: int -@router.get("/models/{model_id:path}/chats", response_model=ModelChatsResponse) +@router.get('/models/{model_id:path}/chats', response_model=ModelChatsResponse) async def get_model_chats( model_id: str, start_date: Optional[int] = Query(None), @@ -311,7 +296,7 @@ async def get_model_chats( continue # Get user_id from first user message - first_user_msg = next((m for m in messages if m.role == "user"), None) + first_user_msg = next((m for m in messages if m.role == 'user'), None) user_id = first_user_msg.user_id if first_user_msg else None # Extract first message content as preview @@ -321,8 +306,8 @@ async def get_model_chats( if isinstance(content, str): first_message = content[:200] elif isinstance(content, list): - text_parts = [b.get("text", "") for b in content if isinstance(b, dict)] - first_message = " ".join(text_parts)[:200] + text_parts = [b.get('text', '') for b in content if isinstance(b, dict)] + first_message = ' '.join(text_parts)[:200] # Get user info user_name = None @@ -367,10 +352,10 @@ class ModelOverviewResponse(BaseModel): tags: list[TagEntry] -@router.get("/models/{model_id:path}/overview", response_model=ModelOverviewResponse) +@router.get('/models/{model_id:path}/overview', response_model=ModelOverviewResponse) async def get_model_overview( model_id: str, - days: int = Query(30, description="Number of days of history (0 for all)"), + days: int = Query(30, description='Number of days of history (0 for all)'), user=Depends(get_admin_user), db: Session = Depends(get_session), ): @@ -387,7 +372,7 @@ async def get_model_overview( ) # Get feedback history per day - history_counts: dict[str, dict] = defaultdict(lambda: {"won": 0, "lost": 0}) + history_counts: dict[str, dict] = defaultdict(lambda: {'won': 0, 'lost': 0}) # Calculate start date for history now = datetime.now() @@ -398,19 +383,19 @@ async def get_model_overview( for chat_id in chat_ids: feedbacks = Feedbacks.get_feedbacks_by_chat_id(chat_id, db=db) for fb in feedbacks: - if fb.data and "rating" in fb.data: - rating = fb.data["rating"] + if fb.data and 'rating' in fb.data: + rating = fb.data['rating'] fb_date = datetime.fromtimestamp(fb.created_at) # Filter by date range if start_dt and fb_date < start_dt: continue - date_str = fb_date.strftime("%Y-%m-%d") + date_str = fb_date.strftime('%Y-%m-%d') if rating == 1: - history_counts[date_str]["won"] += 1 + history_counts[date_str]['won'] += 1 elif rating == -1: - history_counts[date_str]["lost"] += 1 + history_counts[date_str]['lost'] += 1 # Fill in missing days history = [] @@ -421,18 +406,18 @@ async def get_model_overview( elif history_counts: # Find earliest date min_date = min(history_counts.keys()) - current = datetime.strptime(min_date, "%Y-%m-%d") + current = datetime.strptime(min_date, '%Y-%m-%d') else: current = now while current <= end_dt: - date_str = current.strftime("%Y-%m-%d") - counts = history_counts.get(date_str, {"won": 0, "lost": 0}) + date_str = current.strftime('%Y-%m-%d') + counts = history_counts.get(date_str, {'won': 0, 'lost': 0}) history.append( HistoryEntry( date=date_str, - won=counts["won"], - lost=counts["lost"], + won=counts['won'], + lost=counts['lost'], ) ) current += timedelta(days=1) @@ -442,13 +427,10 @@ async def get_model_overview( for chat_id in chat_ids: chat = Chats.get_chat_by_id(chat_id, db=db) if chat and chat.meta: - for tag in chat.meta.get("tags", []): + for tag in chat.meta.get('tags', []): tag_counts[tag] += 1 # Sort by count and take top 10 - tags = [ - TagEntry(tag=tag, count=count) - for tag, count in sorted(tag_counts.items(), key=lambda x: -x[1])[:10] - ] + tags = [TagEntry(tag=tag, count=count) for tag, count in sorted(tag_counts.items(), key=lambda x: -x[1])[:10]] return ModelOverviewResponse(history=history, tags=tags) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index a1f3ac523f..7e1fd9e3ee 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -68,7 +68,7 @@ AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to byte log = logging.getLogger(__name__) -SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech" +SPEECH_CACHE_DIR = CACHE_DIR / 'audio' / 'speech' SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) @@ -86,19 +86,19 @@ def is_audio_conversion_required(file_path): """ Check if the given audio file needs conversion to mp3. """ - SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"} + SUPPORTED_FORMATS = {'flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'wav', 'webm'} if not os.path.isfile(file_path): - log.error(f"File not found: {file_path}") + log.error(f'File not found: {file_path}') return False try: info = mediainfo(file_path) - codec_name = info.get("codec_name", "").lower() - codec_type = info.get("codec_type", "").lower() - codec_tag_string = info.get("codec_tag_string", "").lower() + codec_name = info.get('codec_name', '').lower() + codec_type = info.get('codec_type', '').lower() + codec_tag_string = info.get('codec_tag_string', '').lower() - if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a": + if codec_name == 'aac' and codec_type == 'audio' and codec_tag_string == 'mp4a': # File is AAC/mp4a audio, recommend mp3 conversion return True @@ -108,20 +108,20 @@ def is_audio_conversion_required(file_path): return True except Exception as e: - log.error(f"Error getting audio format: {e}") + log.error(f'Error getting audio format: {e}') return False def convert_audio_to_mp3(file_path): """Convert audio file to mp3 format.""" try: - output_path = os.path.splitext(file_path)[0] + ".mp3" + output_path = os.path.splitext(file_path)[0] + '.mp3' audio = AudioSegment.from_file(file_path) - audio.export(output_path, format="mp3") - log.info(f"Converted {file_path} to {output_path}") + audio.export(output_path, format='mp3') + log.info(f'Converted {file_path} to {output_path}') return output_path except Exception as e: - log.error(f"Error converting audio file: {e}") + log.error(f'Error converting audio file: {e}') return None @@ -131,20 +131,18 @@ def set_faster_whisper_model(model: str, auto_update: bool = False): from faster_whisper import WhisperModel faster_whisper_kwargs = { - "model_size_or_path": model, - "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", - "compute_type": WHISPER_COMPUTE_TYPE, - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not auto_update, + 'model_size_or_path': model, + 'device': DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == 'cuda' else 'cpu', + 'compute_type': WHISPER_COMPUTE_TYPE, + 'download_root': WHISPER_MODEL_DIR, + 'local_files_only': not auto_update, } try: whisper_model = WhisperModel(**faster_whisper_kwargs) except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - faster_whisper_kwargs["local_files_only"] = False + log.warning('WhisperModel initialization failed, attempting download with local_files_only=False') + faster_whisper_kwargs['local_files_only'] = False whisper_model = WhisperModel(**faster_whisper_kwargs) return whisper_model @@ -193,46 +191,44 @@ class AudioConfigUpdateForm(BaseModel): stt: STTConfigForm -@router.get("/config") +@router.get('/config') async def get_audio_config(request: Request, user=Depends(get_admin_user)): return { - "tts": { - "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, - "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS, - "API_KEY": request.app.state.config.TTS_API_KEY, - "ENGINE": request.app.state.config.TTS_ENGINE, - "MODEL": request.app.state.config.TTS_MODEL, - "VOICE": request.app.state.config.TTS_VOICE, - "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, - "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + 'tts': { + 'OPENAI_API_BASE_URL': request.app.state.config.TTS_OPENAI_API_BASE_URL, + 'OPENAI_API_KEY': request.app.state.config.TTS_OPENAI_API_KEY, + 'OPENAI_PARAMS': request.app.state.config.TTS_OPENAI_PARAMS, + 'API_KEY': request.app.state.config.TTS_API_KEY, + 'ENGINE': request.app.state.config.TTS_ENGINE, + 'MODEL': request.app.state.config.TTS_MODEL, + 'VOICE': request.app.state.config.TTS_VOICE, + 'SPLIT_ON': request.app.state.config.TTS_SPLIT_ON, + 'AZURE_SPEECH_REGION': request.app.state.config.TTS_AZURE_SPEECH_REGION, + 'AZURE_SPEECH_BASE_URL': request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, + 'AZURE_SPEECH_OUTPUT_FORMAT': request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, - "stt": { - "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, - "ENGINE": request.app.state.config.STT_ENGINE, - "MODEL": request.app.state.config.STT_MODEL, - "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, - "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, - "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, - "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY, - "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION, - "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES, - "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL, - "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, - "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, - "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, - "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, + 'stt': { + 'OPENAI_API_BASE_URL': request.app.state.config.STT_OPENAI_API_BASE_URL, + 'OPENAI_API_KEY': request.app.state.config.STT_OPENAI_API_KEY, + 'ENGINE': request.app.state.config.STT_ENGINE, + 'MODEL': request.app.state.config.STT_MODEL, + 'SUPPORTED_CONTENT_TYPES': request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, + 'WHISPER_MODEL': request.app.state.config.WHISPER_MODEL, + 'DEEPGRAM_API_KEY': request.app.state.config.DEEPGRAM_API_KEY, + 'AZURE_API_KEY': request.app.state.config.AUDIO_STT_AZURE_API_KEY, + 'AZURE_REGION': request.app.state.config.AUDIO_STT_AZURE_REGION, + 'AZURE_LOCALES': request.app.state.config.AUDIO_STT_AZURE_LOCALES, + 'AZURE_BASE_URL': request.app.state.config.AUDIO_STT_AZURE_BASE_URL, + 'AZURE_MAX_SPEAKERS': request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, + 'MISTRAL_API_KEY': request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, + 'MISTRAL_API_BASE_URL': request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, + 'MISTRAL_USE_CHAT_COMPLETIONS': request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, }, } -@router.post("/config/update") -async def update_audio_config( - request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_audio_config(request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)): request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY request.app.state.config.TTS_OPENAI_PARAMS = form_data.tts.OPENAI_PARAMS @@ -242,20 +238,14 @@ async def update_audio_config( request.app.state.config.TTS_VOICE = form_data.tts.VOICE request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION - request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = ( - form_data.tts.AZURE_SPEECH_BASE_URL - ) - request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( - form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT - ) + request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = form_data.tts.AZURE_SPEECH_BASE_URL + request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY request.app.state.config.STT_ENGINE = form_data.stt.ENGINE request.app.state.config.STT_MODEL = form_data.stt.MODEL - request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = ( - form_data.stt.SUPPORTED_CONTENT_TYPES - ) + request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = form_data.stt.SUPPORTED_CONTENT_TYPES request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY @@ -263,18 +253,12 @@ async def update_audio_config( request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL - request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = ( - form_data.stt.AZURE_MAX_SPEAKERS - ) + request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = form_data.stt.AZURE_MAX_SPEAKERS request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY - request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = ( - form_data.stt.MISTRAL_API_BASE_URL - ) - request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = ( - form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS - ) + request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = form_data.stt.MISTRAL_API_BASE_URL + request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS - if request.app.state.config.STT_ENGINE == "": + if request.app.state.config.STT_ENGINE == '': request.app.state.faster_whisper_model = set_faster_whisper_model( form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE ) @@ -282,35 +266,35 @@ async def update_audio_config( request.app.state.faster_whisper_model = None return { - "tts": { - "ENGINE": request.app.state.config.TTS_ENGINE, - "MODEL": request.app.state.config.TTS_MODEL, - "VOICE": request.app.state.config.TTS_VOICE, - "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, - "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS, - "API_KEY": request.app.state.config.TTS_API_KEY, - "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, - "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + 'tts': { + 'ENGINE': request.app.state.config.TTS_ENGINE, + 'MODEL': request.app.state.config.TTS_MODEL, + 'VOICE': request.app.state.config.TTS_VOICE, + 'OPENAI_API_BASE_URL': request.app.state.config.TTS_OPENAI_API_BASE_URL, + 'OPENAI_API_KEY': request.app.state.config.TTS_OPENAI_API_KEY, + 'OPENAI_PARAMS': request.app.state.config.TTS_OPENAI_PARAMS, + 'API_KEY': request.app.state.config.TTS_API_KEY, + 'SPLIT_ON': request.app.state.config.TTS_SPLIT_ON, + 'AZURE_SPEECH_REGION': request.app.state.config.TTS_AZURE_SPEECH_REGION, + 'AZURE_SPEECH_BASE_URL': request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, + 'AZURE_SPEECH_OUTPUT_FORMAT': request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, - "stt": { - "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, - "ENGINE": request.app.state.config.STT_ENGINE, - "MODEL": request.app.state.config.STT_MODEL, - "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, - "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, - "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, - "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY, - "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION, - "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES, - "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL, - "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, - "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, - "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, - "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, + 'stt': { + 'OPENAI_API_BASE_URL': request.app.state.config.STT_OPENAI_API_BASE_URL, + 'OPENAI_API_KEY': request.app.state.config.STT_OPENAI_API_KEY, + 'ENGINE': request.app.state.config.STT_ENGINE, + 'MODEL': request.app.state.config.STT_MODEL, + 'SUPPORTED_CONTENT_TYPES': request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, + 'WHISPER_MODEL': request.app.state.config.WHISPER_MODEL, + 'DEEPGRAM_API_KEY': request.app.state.config.DEEPGRAM_API_KEY, + 'AZURE_API_KEY': request.app.state.config.AUDIO_STT_AZURE_API_KEY, + 'AZURE_REGION': request.app.state.config.AUDIO_STT_AZURE_REGION, + 'AZURE_LOCALES': request.app.state.config.AUDIO_STT_AZURE_LOCALES, + 'AZURE_BASE_URL': request.app.state.config.AUDIO_STT_AZURE_BASE_URL, + 'AZURE_MAX_SPEAKERS': request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, + 'MISTRAL_API_KEY': request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, + 'MISTRAL_API_BASE_URL': request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, + 'MISTRAL_USE_CHAT_COMPLETIONS': request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, }, } @@ -320,27 +304,23 @@ def load_speech_pipeline(request): from datasets import load_dataset if request.app.state.speech_synthesiser is None: - request.app.state.speech_synthesiser = pipeline( - "text-to-speech", "microsoft/speecht5_tts" - ) + request.app.state.speech_synthesiser = pipeline('text-to-speech', 'microsoft/speecht5_tts') if request.app.state.speech_speaker_embeddings_dataset is None: request.app.state.speech_speaker_embeddings_dataset = load_dataset( - "Matthijs/cmu-arctic-xvectors", split="validation" + 'Matthijs/cmu-arctic-xvectors', split='validation' ) -@router.post("/speech") +@router.post('/speech') async def speech(request: Request, user=Depends(get_verified_user)): - if request.app.state.config.TTS_ENGINE == "": + if request.app.state.config.TTS_ENGINE == '': raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - if user.role != "admin" and not has_permission( - user.id, "chat.tts", request.app.state.config.USER_PERMISSIONS - ): + if user.role != 'admin' and not has_permission(user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -349,12 +329,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): body = await request.body() name = hashlib.sha256( body - + str(request.app.state.config.TTS_ENGINE).encode("utf-8") - + str(request.app.state.config.TTS_MODEL).encode("utf-8") + + str(request.app.state.config.TTS_ENGINE).encode('utf-8') + + str(request.app.state.config.TTS_MODEL).encode('utf-8') ).hexdigest() - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + file_path = SPEECH_CACHE_DIR.joinpath(f'{name}.mp3') + file_body_path = SPEECH_CACHE_DIR.joinpath(f'{name}.json') # Check if the file already exists in the cache if file_path.is_file(): @@ -362,34 +342,32 @@ async def speech(request: Request, user=Depends(get_verified_user)): payload = None try: - payload = json.loads(body.decode("utf-8")) + payload = json.loads(body.decode('utf-8')) except Exception as e: log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") + raise HTTPException(status_code=400, detail='Invalid JSON payload') r = None - if request.app.state.config.TTS_ENGINE == "openai": - payload["model"] = request.app.state.config.TTS_MODEL + if request.app.state.config.TTS_ENGINE == 'openai': + payload['model'] = request.app.state.config.TTS_MODEL try: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - async with aiohttp.ClientSession( - timeout=timeout, trust_env=True - ) as session: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: payload = { **payload, **(request.app.state.config.TTS_OPENAI_PARAMS or {}), } headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {request.app.state.config.TTS_OPENAI_API_KEY}', } if ENABLE_FORWARD_USER_INFO_HEADERS: headers = include_user_info_headers(headers, user) r = await session.post( - url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + url=f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech', json=payload, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -397,10 +375,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): r.raise_for_status() - async with aiofiles.open(file_path, "wb") as f: + async with aiofiles.open(file_path, 'wb') as f: await f.write(await r.read()) - async with aiofiles.open(file_body_path, "w") as f: + async with aiofiles.open(file_body_path, 'w') as f: await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -410,57 +388,55 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail = None status_code = 500 - detail = f"Open WebUI: Server Connection Error" + detail = f'Open WebUI: Server Connection Error' if r is not None: status_code = r.status try: res = await r.json() - if "error" in res: - detail = f"External: {res['error']}" + if 'error' in res: + detail = f'External: {res["error"]}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( status_code=status_code, detail=detail, ) - elif request.app.state.config.TTS_ENGINE == "elevenlabs": - voice_id = payload.get("voice", "") + elif request.app.state.config.TTS_ENGINE == 'elevenlabs': + voice_id = payload.get('voice', '') if voice_id not in get_available_voices(request): raise HTTPException( status_code=400, - detail="Invalid voice id", + detail='Invalid voice id', ) try: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - async with aiohttp.ClientSession( - timeout=timeout, trust_env=True - ) as session: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.post( - f"{ELEVENLABS_API_BASE_URL}/v1/text-to-speech/{voice_id}", + f'{ELEVENLABS_API_BASE_URL}/v1/text-to-speech/{voice_id}', json={ - "text": payload["input"], - "model_id": request.app.state.config.TTS_MODEL, - "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, + 'text': payload['input'], + 'model_id': request.app.state.config.TTS_MODEL, + 'voice_settings': {'stability': 0.5, 'similarity_boost': 0.5}, }, headers={ - "Accept": "audio/mpeg", - "Content-Type": "application/json", - "xi-api-key": request.app.state.config.TTS_API_KEY, + 'Accept': 'audio/mpeg', + 'Content-Type': 'application/json', + 'xi-api-key': request.app.state.config.TTS_API_KEY, }, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() - async with aiofiles.open(file_path, "wb") as f: + async with aiofiles.open(file_path, 'wb') as f: await f.write(await r.read()) - async with aiofiles.open(file_body_path, "w") as f: + async with aiofiles.open(file_body_path, 'w') as f: await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -472,54 +448,51 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: if r.status != 200: res = await r.json() - if "error" in res: - detail = f"External: {res['error'].get('message', '')}" + if 'error' in res: + detail = f'External: {res["error"].get("message", "")}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( - status_code=getattr(r, "status", 500) if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + status_code=getattr(r, 'status', 500) if r else 500, + detail=detail if detail else 'Open WebUI: Server Connection Error', ) - elif request.app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == 'azure': try: - payload = json.loads(body.decode("utf-8")) + payload = json.loads(body.decode('utf-8')) except Exception as e: log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") + raise HTTPException(status_code=400, detail='Invalid JSON payload') - region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus" + region = request.app.state.config.TTS_AZURE_SPEECH_REGION or 'eastus' base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL language = request.app.state.config.TTS_VOICE - locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:2]) + locale = '-'.join(request.app.state.config.TTS_VOICE.split('-')[:2]) output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT try: data = f""" - {html.escape(payload["input"])} + {html.escape(payload['input'])} """ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - async with aiohttp.ClientSession( - timeout=timeout, trust_env=True - ) as session: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.post( - (base_url or f"https://{region}.tts.speech.microsoft.com") - + "/cognitiveservices/v1", + (base_url or f'https://{region}.tts.speech.microsoft.com') + '/cognitiveservices/v1', headers={ - "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": output_format, + 'Ocp-Apim-Subscription-Key': request.app.state.config.TTS_API_KEY, + 'Content-Type': 'application/ssml+xml', + 'X-Microsoft-OutputFormat': output_format, }, data=data, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() - async with aiofiles.open(file_path, "wb") as f: + async with aiofiles.open(file_path, 'wb') as f: await f.write(await r.read()) - async with aiofiles.open(file_body_path, "w") as f: + async with aiofiles.open(file_body_path, 'w') as f: await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -531,23 +504,23 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: if r.status != 200: res = await r.json() - if "error" in res: - detail = f"External: {res['error'].get('message', '')}" + if 'error' in res: + detail = f'External: {res["error"].get("message", "")}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( - status_code=getattr(r, "status", 500) if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + status_code=getattr(r, 'status', 500) if r else 500, + detail=detail if detail else 'Open WebUI: Server Connection Error', ) - elif request.app.state.config.TTS_ENGINE == "transformers": + elif request.app.state.config.TTS_ENGINE == 'transformers': payload = None try: - payload = json.loads(body.decode("utf-8")) + payload = json.loads(body.decode('utf-8')) except Exception as e: log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") + raise HTTPException(status_code=400, detail='Invalid JSON payload') import torch import soundfile as sf @@ -558,24 +531,20 @@ async def speech(request: Request, user=Depends(get_verified_user)): speaker_index = 6799 try: - speaker_index = embeddings_dataset["filename"].index( - request.app.state.config.TTS_MODEL - ) + speaker_index = embeddings_dataset['filename'].index(request.app.state.config.TTS_MODEL) except Exception: pass - speaker_embedding = torch.tensor( - embeddings_dataset[speaker_index]["xvector"] - ).unsqueeze(0) + speaker_embedding = torch.tensor(embeddings_dataset[speaker_index]['xvector']).unsqueeze(0) speech = request.app.state.speech_synthesiser( - payload["input"], - forward_params={"speaker_embeddings": speaker_embedding}, + payload['input'], + forward_params={'speaker_embeddings': speaker_embedding}, ) - sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"]) + sf.write(file_path, speech['audio'], samplerate=speech['sampling_rate']) - async with aiofiles.open(file_body_path, "w") as f: + async with aiofiles.open(file_body_path, 'w') as f: await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -584,20 +553,18 @@ async def speech(request: Request, user=Depends(get_verified_user)): def transcription_handler(request, file_path, metadata, user=None): filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) - id = filename.split(".")[0] + id = filename.split('.')[0] metadata = metadata or {} languages = [ - metadata.get("language", None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE, + metadata.get('language', None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE, None, # Always fallback to None in case transcription fails ] - if request.app.state.config.STT_ENGINE == "": + if request.app.state.config.STT_ENGINE == '': if request.app.state.faster_whisper_model is None: - request.app.state.faster_whisper_model = set_faster_whisper_model( - request.app.state.config.WHISPER_MODEL - ) + request.app.state.faster_whisper_model = set_faster_whisper_model(request.app.state.config.WHISPER_MODEL) model = request.app.state.faster_whisper_model segments, info = model.transcribe( @@ -607,43 +574,38 @@ def transcription_handler(request, file_path, metadata, user=None): language=languages[0], multilingual=WHISPER_MULTILINGUAL, ) - log.info( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) + log.info("Detected language '%s' with probability %f" % (info.language, info.language_probability)) - transcript = "".join([segment.text for segment in list(segments)]) - data = {"text": transcript.strip()} + transcript = ''.join([segment.text for segment in list(segments)]) + data = {'text': transcript.strip()} # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: + transcript_file = f'{file_dir}/{id}.json' + with open(transcript_file, 'w') as f: json.dump(data, f) log.debug(data) return data - elif request.app.state.config.STT_ENGINE == "openai": + elif request.app.state.config.STT_ENGINE == 'openai': r = None try: for language in languages: payload = { - "model": request.app.state.config.STT_MODEL, + 'model': request.app.state.config.STT_MODEL, } if language: - payload["language"] = language + payload['language'] = language - headers = { - "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" - } + headers = {'Authorization': f'Bearer {request.app.state.config.STT_OPENAI_API_KEY}'} if user and ENABLE_FORWARD_USER_INFO_HEADERS: headers = include_user_info_headers(headers, user) - with open(file_path, "rb") as audio_file: + with open(file_path, 'rb') as audio_file: r = requests.post( - url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + url=f'{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions', headers=headers, - files={"file": (filename, audio_file)}, + files={'file': (filename, audio_file)}, data=payload, timeout=AIOHTTP_CLIENT_TIMEOUT, ) @@ -656,8 +618,8 @@ def transcription_handler(request, file_path, metadata, user=None): data = r.json() # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: + transcript_file = f'{file_dir}/{id}.json' + with open(transcript_file, 'w') as f: json.dump(data, f) return data @@ -668,41 +630,41 @@ def transcription_handler(request, file_path, metadata, user=None): if r is not None: try: res = r.json() - if "error" in res: - detail = f"External: {res['error'].get('message', '')}" + if 'error' in res: + detail = f'External: {res["error"].get("message", "")}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' - raise Exception(detail if detail else "Open WebUI: Server Connection Error") + raise Exception(detail if detail else 'Open WebUI: Server Connection Error') - elif request.app.state.config.STT_ENGINE == "deepgram": + elif request.app.state.config.STT_ENGINE == 'deepgram': try: # Determine the MIME type of the file mime, _ = mimetypes.guess_type(file_path) if not mime: - mime = "audio/wav" # fallback to wav if undetectable + mime = 'audio/wav' # fallback to wav if undetectable # Read the audio file - with open(file_path, "rb") as f: + with open(file_path, 'rb') as f: file_data = f.read() # Build headers and parameters headers = { - "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}", - "Content-Type": mime, + 'Authorization': f'Token {request.app.state.config.DEEPGRAM_API_KEY}', + 'Content-Type': mime, } for language in languages: params = {} if request.app.state.config.STT_MODEL: - params["model"] = request.app.state.config.STT_MODEL + params['model'] = request.app.state.config.STT_MODEL if language: - params["language"] = language + params['language'] = language # Make request to Deepgram API r = requests.post( - "https://api.deepgram.com/v1/listen?smart_format=true", + 'https://api.deepgram.com/v1/listen?smart_format=true', headers=headers, params=params, data=file_data, @@ -718,19 +680,15 @@ def transcription_handler(request, file_path, metadata, user=None): # Extract transcript from Deepgram response try: - transcript = response_data["results"]["channels"][0]["alternatives"][ - 0 - ].get("transcript", "") + transcript = response_data['results']['channels'][0]['alternatives'][0].get('transcript', '') except (KeyError, IndexError) as e: - log.error(f"Malformed response from Deepgram: {str(e)}") - raise Exception( - "Failed to parse Deepgram response - unexpected response format" - ) - data = {"text": transcript.strip()} + log.error(f'Malformed response from Deepgram: {str(e)}') + raise Exception('Failed to parse Deepgram response - unexpected response format') + data = {'text': transcript.strip()} # Save transcript - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: + transcript_file = f'{file_dir}/{id}.json' + with open(transcript_file, 'w') as f: json.dump(data, f) return data @@ -741,16 +699,16 @@ def transcription_handler(request, file_path, metadata, user=None): if r is not None: try: res = r.json() - if "error" in res: - detail = f"External: {res['error'].get('message', '')}" + if 'error' in res: + detail = f'External: {res["error"].get("message", "")}' except Exception: - detail = f"External: {e}" - raise Exception(detail if detail else "Open WebUI: Server Connection Error") + detail = f'External: {e}' + raise Exception(detail if detail else 'Open WebUI: Server Connection Error') - elif request.app.state.config.STT_ENGINE == "azure": + elif request.app.state.config.STT_ENGINE == 'azure': # Check file exists and size if not os.path.exists(file_path): - raise HTTPException(status_code=400, detail="Audio file not found") + raise HTTPException(status_code=400, detail='Audio file not found') # Check file size (Azure has a larger limit of 200MB) file_size = os.path.getsize(file_path) @@ -761,7 +719,7 @@ def transcription_handler(request, file_path, metadata, user=None): ) api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY - region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus" + region = request.app.state.config.AUDIO_STT_AZURE_REGION or 'eastus' locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3 @@ -769,36 +727,36 @@ def transcription_handler(request, file_path, metadata, user=None): # IF NO LOCALES, USE DEFAULTS if len(locales) < 2: locales = [ - "en-US", - "es-ES", - "es-MX", - "fr-FR", - "hi-IN", - "it-IT", - "de-DE", - "en-GB", - "en-IN", - "ja-JP", - "ko-KR", - "pt-BR", - "zh-CN", + 'en-US', + 'es-ES', + 'es-MX', + 'fr-FR', + 'hi-IN', + 'it-IT', + 'de-DE', + 'en-GB', + 'en-IN', + 'ja-JP', + 'ko-KR', + 'pt-BR', + 'zh-CN', ] - locales = ",".join(locales) + locales = ','.join(locales) if not api_key or not region: raise HTTPException( status_code=400, - detail="Azure API key is required for Azure STT", + detail='Azure API key is required for Azure STT', ) r = None try: # Prepare the request data = { - "definition": json.dumps( + 'definition': json.dumps( { - "locales": locales.split(","), - "diarization": {"maxSpeakers": max_speakers, "enabled": True}, + 'locales': locales.split(','), + 'diarization': {'maxSpeakers': max_speakers, 'enabled': True}, } if locales else {} @@ -806,17 +764,17 @@ def transcription_handler(request, file_path, metadata, user=None): } url = ( - base_url or f"https://{region}.api.cognitive.microsoft.com" - ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15" + base_url or f'https://{region}.api.cognitive.microsoft.com' + ) + '/speechtotext/transcriptions:transcribe?api-version=2024-11-15' # Use context manager to ensure file is properly closed - with open(file_path, "rb") as audio_file: + with open(file_path, 'rb') as audio_file: r = requests.post( url=url, - files={"audio": audio_file}, + files={'audio': audio_file}, data=data, headers={ - "Ocp-Apim-Subscription-Key": api_key, + 'Ocp-Apim-Subscription-Key': api_key, }, timeout=AIOHTTP_CLIENT_TIMEOUT, ) @@ -825,100 +783,93 @@ def transcription_handler(request, file_path, metadata, user=None): response = r.json() # Extract transcript from response - if not response.get("combinedPhrases"): - raise ValueError("No transcription found in response") + if not response.get('combinedPhrases'): + raise ValueError('No transcription found in response') # Get the full transcript from combinedPhrases - transcript = response["combinedPhrases"][0].get("text", "").strip() + transcript = response['combinedPhrases'][0].get('text', '').strip() if not transcript: - raise ValueError("Empty transcript in response") + raise ValueError('Empty transcript in response') - data = {"text": transcript} + data = {'text': transcript} # Save transcript to json file (consistent with other providers) - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: + transcript_file = f'{file_dir}/{id}.json' + with open(transcript_file, 'w') as f: json.dump(data, f) log.debug(data) return data except (KeyError, IndexError, ValueError) as e: - log.exception("Error parsing Azure response") + log.exception('Error parsing Azure response') raise HTTPException( status_code=500, - detail=f"Failed to parse Azure response: {str(e)}", + detail=f'Failed to parse Azure response: {str(e)}', ) except requests.exceptions.RequestException as e: log.exception(e) detail = None - status_code = getattr(r, "status_code", 500) if r else 500 + status_code = getattr(r, 'status_code', 500) if r else 500 try: if r is not None and r.status_code != 200: res = r.json() # Azure returns {"code": "...", "message": "...", "innerError": {...}} - if "code" in res and "message" in res: - azure_code = res.get("innerError", {}).get("code", res["code"]) + if 'code' in res and 'message' in res: + azure_code = res.get('innerError', {}).get('code', res['code']) user_facing_codes = { - "EmptyAudioFile", - "AudioLengthLimitExceeded", - "NoLanguageIdentified", - "MultipleLanguagesIdentified", + 'EmptyAudioFile', + 'AudioLengthLimitExceeded', + 'NoLanguageIdentified', + 'MultipleLanguagesIdentified', } if azure_code in user_facing_codes: - detail = res["message"] + detail = res['message'] else: - log.error( - f"Azure STT error [{azure_code}]: {res['message']}" - ) - detail = "An error occurred during transcription." - elif "error" in res: - detail = f"External: {res['error'].get('message', '')}" + log.error(f'Azure STT error [{azure_code}]: {res["message"]}') + detail = 'An error occurred during transcription.' + elif 'error' in res: + detail = f'External: {res["error"].get("message", "")}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( status_code=status_code, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) - elif request.app.state.config.STT_ENGINE == "mistral": + elif request.app.state.config.STT_ENGINE == 'mistral': # Check file exists if not os.path.exists(file_path): - raise HTTPException(status_code=400, detail="Audio file not found") + raise HTTPException(status_code=400, detail='Audio file not found') # Check file size file_size = os.path.getsize(file_path) if file_size > MAX_FILE_SIZE: raise HTTPException( status_code=400, - detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB}MB", + detail=f'File size exceeds limit of {MAX_FILE_SIZE_MB}MB', ) api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY - api_base_url = ( - request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL - or "https://api.mistral.ai/v1" - ) - use_chat_completions = ( - request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS - ) + api_base_url = request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL or 'https://api.mistral.ai/v1' + use_chat_completions = request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS if not api_key: raise HTTPException( status_code=400, - detail="Mistral API key is required for Mistral STT", + detail='Mistral API key is required for Mistral STT', ) r = None try: # Use voxtral-mini-latest as the default model for transcription - model = request.app.state.config.STT_MODEL or "voxtral-mini-latest" + model = request.app.state.config.STT_MODEL or 'voxtral-mini-latest' log.info( - f"Mistral STT - model: {model}, " - f"method: {'chat_completions' if use_chat_completions else 'transcriptions'}" + f'Mistral STT - model: {model}, ' + f'method: {"chat_completions" if use_chat_completions else "transcriptions"}' ) if use_chat_completions: @@ -927,42 +878,42 @@ def transcription_handler(request, file_path, metadata, user=None): audio_file_to_use = file_path if is_audio_conversion_required(file_path): - log.debug("Converting audio to mp3 for chat completions API") + log.debug('Converting audio to mp3 for chat completions API') converted_path = convert_audio_to_mp3(file_path) if converted_path: audio_file_to_use = converted_path else: - log.error("Audio conversion failed") + log.error('Audio conversion failed') raise HTTPException( status_code=500, - detail="Audio conversion failed. Chat completions API requires mp3 or wav format.", + detail='Audio conversion failed. Chat completions API requires mp3 or wav format.', ) # Read and encode audio file as base64 - with open(audio_file_to_use, "rb") as audio_file: - audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8") + with open(audio_file_to_use, 'rb') as audio_file: + audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8') # Prepare chat completions request - url = f"{api_base_url}/chat/completions" + url = f'{api_base_url}/chat/completions' # Add language instruction if specified - language = metadata.get("language", None) if metadata else None + language = metadata.get('language', None) if metadata else None if language: - text_instruction = f"Transcribe this audio exactly as spoken in {language}. Do not translate it." + text_instruction = f'Transcribe this audio exactly as spoken in {language}. Do not translate it.' else: - text_instruction = "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language." + text_instruction = 'Transcribe this audio exactly as spoken in its original language. Do not translate it to another language.' payload = { - "model": model, - "messages": [ + 'model': model, + 'messages': [ { - "role": "user", - "content": [ + 'role': 'user', + 'content': [ { - "type": "input_audio", - "input_audio": audio_base64, + 'type': 'input_audio', + 'input_audio': audio_base64, }, - {"type": "text", "text": text_instruction}, + {'type': 'text', 'text': text_instruction}, ], } ], @@ -972,8 +923,8 @@ def transcription_handler(request, file_path, metadata, user=None): url=url, json=payload, headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json', }, timeout=AIOHTTP_CLIENT_TIMEOUT, ) @@ -982,42 +933,37 @@ def transcription_handler(request, file_path, metadata, user=None): response = r.json() # Extract transcript from chat completion response - transcript = ( - response.get("choices", [{}])[0] - .get("message", {}) - .get("content", "") - .strip() - ) + transcript = response.get('choices', [{}])[0].get('message', {}).get('content', '').strip() if not transcript: - raise ValueError("Empty transcript in response") + raise ValueError('Empty transcript in response') - data = {"text": transcript} + data = {'text': transcript} else: # Use dedicated transcriptions API - url = f"{api_base_url}/audio/transcriptions" + url = f'{api_base_url}/audio/transcriptions' # Determine the MIME type mime_type, _ = mimetypes.guess_type(file_path) if not mime_type: - mime_type = "audio/webm" + mime_type = 'audio/webm' # Use context manager to ensure file is properly closed - with open(file_path, "rb") as audio_file: - files = {"file": (filename, audio_file, mime_type)} - data_form = {"model": model} + with open(file_path, 'rb') as audio_file: + files = {'file': (filename, audio_file, mime_type)} + data_form = {'model': model} # Add language if specified in metadata - language = metadata.get("language", None) if metadata else None + language = metadata.get('language', None) if metadata else None if language: - data_form["language"] = language + data_form['language'] = language r = requests.post( url=url, files=files, data=data_form, headers={ - "Authorization": f"Bearer {api_key}", + 'Authorization': f'Bearer {api_key}', }, timeout=AIOHTTP_CLIENT_TIMEOUT, ) @@ -1026,25 +972,25 @@ def transcription_handler(request, file_path, metadata, user=None): response = r.json() # Extract transcript from response - transcript = response.get("text", "").strip() + transcript = response.get('text', '').strip() if not transcript: - raise ValueError("Empty transcript in response") + raise ValueError('Empty transcript in response') - data = {"text": transcript} + data = {'text': transcript} # Save transcript to json file (consistent with other providers) - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: + transcript_file = f'{file_dir}/{id}.json' + with open(transcript_file, 'w') as f: json.dump(data, f) log.debug(data) return data except ValueError as e: - log.exception("Error parsing Mistral response") + log.exception('Error parsing Mistral response') raise HTTPException( status_code=500, - detail=f"Failed to parse Mistral response: {str(e)}", + detail=f'Failed to parse Mistral response: {str(e)}', ) except requests.exceptions.RequestException as e: log.exception(e) @@ -1053,23 +999,21 @@ def transcription_handler(request, file_path, metadata, user=None): try: if r is not None and r.status_code != 200: res = r.json() - if "error" in res: - detail = f"External: {res['error'].get('message', '')}" + if 'error' in res: + detail = f'External: {res["error"].get("message", "")}' else: - detail = f"External: {r.text}" + detail = f'External: {r.text}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( - status_code=getattr(r, "status_code", 500) if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + status_code=getattr(r, 'status_code', 500) if r else 500, + detail=detail if detail else 'Open WebUI: Server Connection Error', ) -def transcribe( - request: Request, file_path: str, metadata: Optional[dict] = None, user=None -): - log.info(f"transcribe: {file_path} {metadata}") +def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None, user=None): + log.info(f'transcribe: {file_path} {metadata}') if is_audio_conversion_required(file_path): file_path = convert_audio_to_mp3(file_path) @@ -1082,7 +1026,7 @@ def transcribe( # Always produce a list of chunk paths (could be one entry if small) try: chunk_paths = split_audio(file_path, MAX_FILE_SIZE) - print(f"Chunk paths: {chunk_paths}") + print(f'Chunk paths: {chunk_paths}') except Exception as e: log.exception(e) raise HTTPException( @@ -1095,9 +1039,7 @@ def transcribe( with ThreadPoolExecutor() as executor: # Submit tasks for each chunk_path futures = [ - executor.submit( - transcription_handler, request, chunk_path, metadata, user - ) + executor.submit(transcription_handler, request, chunk_path, metadata, user) for chunk_path in chunk_paths ] # Gather results as they complete @@ -1109,7 +1051,7 @@ def transcribe( except Exception as transcribe_exc: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error transcribing chunk: {transcribe_exc}", + detail=f'Error transcribing chunk: {transcribe_exc}', ) finally: # Clean up only the temporary chunks, never the original file @@ -1121,22 +1063,20 @@ def transcribe( pass return { - "text": " ".join([result["text"] for result in results]), + 'text': ' '.join([result['text'] for result in results]), } def compress_audio(file_path): if os.path.getsize(file_path) > MAX_FILE_SIZE: - id = os.path.splitext(os.path.basename(file_path))[ - 0 - ] # Handles names with multiple dots + id = os.path.splitext(os.path.basename(file_path))[0] # Handles names with multiple dots file_dir = os.path.dirname(file_path) audio = AudioSegment.from_file(file_path) audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio - compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3") - audio.export(compressed_path, format="mp3", bitrate="32k") + compressed_path = os.path.join(file_dir, f'{id}_compressed.mp3') + audio.export(compressed_path, format='mp3', bitrate='32k') # log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined return compressed_path @@ -1144,7 +1084,7 @@ def compress_audio(file_path): return file_path -def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): +def split_audio(file_path, max_bytes, format='mp3', bitrate='32k'): """ Splits audio into chunks not exceeding max_bytes. Returns a list of chunk file paths. If audio fits, returns list with original path. @@ -1167,7 +1107,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): while start < duration_ms: end = min(start + approx_chunk_ms, duration_ms) chunk = audio[start:end] - chunk_path = f"{base}_chunk_{i}.{format}" + chunk_path = f'{base}_chunk_{i}.{format}' chunk.export(chunk_path, format=format, bitrate=bitrate) # Reduce chunk duration if still too large @@ -1178,7 +1118,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): if os.path.getsize(chunk_path) > max_bytes: os.remove(chunk_path) - raise Exception("Audio chunk cannot be reduced below max file size.") + raise Exception('Audio chunk cannot be reduced below max file size.') chunks.append(chunk_path) start = end @@ -1187,24 +1127,20 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): return chunks -@router.post("/transcriptions") +@router.post('/transcriptions') def transcription( request: Request, file: UploadFile = File(...), language: Optional[str] = Form(None), user=Depends(get_verified_user), ): - if user.role != "admin" and not has_permission( - user.id, "chat.stt", request.app.state.config.USER_PERMISSIONS - ): + if user.role != 'admin' and not has_permission(user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - log.info(f"file.content_type: {file.content_type}") - stt_supported_content_types = getattr( - request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) + log.info(f'file.content_type: {file.content_type}') + stt_supported_content_types = getattr(request.app.state.config, 'STT_SUPPORTED_CONTENT_TYPES', []) if not strict_match_mime_type(stt_supported_content_types, file.content_type): raise HTTPException( @@ -1213,36 +1149,36 @@ def transcription( ) try: - safe_name = os.path.basename(file.filename) if file.filename else "" - ext = safe_name.rsplit(".", 1)[-1] if "." in safe_name else "" + safe_name = os.path.basename(file.filename) if file.filename else '' + ext = safe_name.rsplit('.', 1)[-1] if '.' in safe_name else '' id = uuid.uuid4() - filename = f"{id}.{ext}" + filename = f'{id}.{ext}' contents = file.file.read() - file_dir = f"{CACHE_DIR}/audio/transcriptions" + file_dir = f'{CACHE_DIR}/audio/transcriptions' os.makedirs(file_dir, exist_ok=True) - file_path = f"{file_dir}/{filename}" + file_path = f'{file_dir}/{filename}' # Defense-in-depth: ensure resolved path stays within intended directory if not os.path.realpath(file_path).startswith(os.path.realpath(file_dir)): - raise ValueError("Invalid file path detected") + raise ValueError('Invalid file path detected') - with open(file_path, "wb") as f: + with open(file_path, 'wb') as f: f.write(contents) try: metadata = None if language: - metadata = {"language": language} + metadata = {'language': language} result = transcribe(request, file_path, metadata, user) return { **result, - "filename": os.path.basename(file_path), + 'filename': os.path.basename(file_path), } except HTTPException: @@ -1252,7 +1188,7 @@ def transcription( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Transcription failed.", + detail='Transcription failed.', ) except HTTPException: @@ -1262,123 +1198,107 @@ def transcription( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Transcription failed.", + detail='Transcription failed.', ) def get_available_models(request: Request) -> list[dict]: available_models = [] - if request.app.state.config.TTS_ENGINE == "openai": + if request.app.state.config.TTS_ENGINE == 'openai': # Use custom endpoint if not using the official OpenAI API URL - if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( - "https://api.openai.com" - ): + if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith('https://api.openai.com'): try: response = requests.get( - f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models", + f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models', timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) response.raise_for_status() data = response.json() - available_models = data.get("models", []) + available_models = data.get('models', []) except Exception as e: - log.error(f"Error fetching models from custom endpoint: {str(e)}") - available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + log.error(f'Error fetching models from custom endpoint: {str(e)}') + available_models = [{'id': 'tts-1'}, {'id': 'tts-1-hd'}] else: - available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] - elif request.app.state.config.TTS_ENGINE == "elevenlabs": + available_models = [{'id': 'tts-1'}, {'id': 'tts-1-hd'}] + elif request.app.state.config.TTS_ENGINE == 'elevenlabs': try: response = requests.get( - f"{ELEVENLABS_API_BASE_URL}/v1/models", + f'{ELEVENLABS_API_BASE_URL}/v1/models', headers={ - "xi-api-key": request.app.state.config.TTS_API_KEY, - "Content-Type": "application/json", + 'xi-api-key': request.app.state.config.TTS_API_KEY, + 'Content-Type': 'application/json', }, timeout=5, ) response.raise_for_status() models = response.json() - available_models = [ - {"name": model["name"], "id": model["model_id"]} for model in models - ] + available_models = [{'name': model['name'], 'id': model['model_id']} for model in models] except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") + log.error(f'Error fetching voices: {str(e)}') return available_models -@router.get("/models") +@router.get('/models') async def get_models(request: Request, user=Depends(get_verified_user)): - return {"models": get_available_models(request)} + return {'models': get_available_models(request)} def get_available_voices(request) -> dict: """Returns {voice_id: voice_name} dict""" available_voices = {} - if request.app.state.config.TTS_ENGINE == "openai": + if request.app.state.config.TTS_ENGINE == 'openai': # Use custom endpoint if not using the official OpenAI API URL - if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( - "https://api.openai.com" - ): + if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith('https://api.openai.com'): try: response = requests.get( - f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices", + f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices', timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) response.raise_for_status() data = response.json() - voices_list = data.get("voices", []) - available_voices = {voice["id"]: voice["name"] for voice in voices_list} + voices_list = data.get('voices', []) + available_voices = {voice['id']: voice['name'] for voice in voices_list} except Exception as e: - log.error(f"Error fetching voices from custom endpoint: {str(e)}") + log.error(f'Error fetching voices from custom endpoint: {str(e)}') available_voices = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", + 'alloy': 'alloy', + 'echo': 'echo', + 'fable': 'fable', + 'onyx': 'onyx', + 'nova': 'nova', + 'shimmer': 'shimmer', } else: available_voices = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", + 'alloy': 'alloy', + 'echo': 'echo', + 'fable': 'fable', + 'onyx': 'onyx', + 'nova': 'nova', + 'shimmer': 'shimmer', } - elif request.app.state.config.TTS_ENGINE == "elevenlabs": + elif request.app.state.config.TTS_ENGINE == 'elevenlabs': try: - available_voices = get_elevenlabs_voices( - api_key=request.app.state.config.TTS_API_KEY - ) + available_voices = get_elevenlabs_voices(api_key=request.app.state.config.TTS_API_KEY) except Exception: # Avoided @lru_cache with exception pass - elif request.app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == 'azure': try: region = request.app.state.config.TTS_AZURE_SPEECH_REGION base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL - url = ( - base_url or f"https://{region}.tts.speech.microsoft.com" - ) + "/cognitiveservices/voices/list" - headers = { - "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY - } + url = (base_url or f'https://{region}.tts.speech.microsoft.com') + '/cognitiveservices/voices/list' + headers = {'Ocp-Apim-Subscription-Key': request.app.state.config.TTS_API_KEY} - response = requests.get( - url, headers=headers, timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST - ) + response = requests.get(url, headers=headers, timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) response.raise_for_status() voices = response.json() for voice in voices: - available_voices[voice["ShortName"]] = ( - f"{voice['DisplayName']} ({voice['ShortName']})" - ) + available_voices[voice['ShortName']] = f'{voice["DisplayName"]} ({voice["ShortName"]})' except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") + log.error(f'Error fetching voices: {str(e)}') return available_voices @@ -1396,10 +1316,10 @@ def get_elevenlabs_voices(api_key: str) -> dict: try: # TODO: Add retries response = requests.get( - f"{ELEVENLABS_API_BASE_URL}/v1/voices", + f'{ELEVENLABS_API_BASE_URL}/v1/voices', headers={ - "xi-api-key": api_key, - "Content-Type": "application/json", + 'xi-api-key': api_key, + 'Content-Type': 'application/json', }, timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) @@ -1407,20 +1327,16 @@ def get_elevenlabs_voices(api_key: str) -> dict: voices_data = response.json() voices = {} - for voice in voices_data.get("voices", []): - voices[voice["voice_id"]] = voice["name"] + for voice in voices_data.get('voices', []): + voices[voice['voice_id']] = voice['name'] except requests.RequestException as e: # Avoid @lru_cache with exception - log.error(f"Error fetching voices: {str(e)}") - raise RuntimeError(f"Error fetching voices: {str(e)}") + log.error(f'Error fetching voices: {str(e)}') + raise RuntimeError(f'Error fetching voices: {str(e)}') return voices -@router.get("/voices") +@router.get('/voices') async def get_voices(request: Request, user=Depends(get_verified_user)): - return { - "voices": [ - {"id": k, "name": v} for k, v in get_available_voices(request).items() - ] - } + return {'voices': [{'id': k, 'name': v} for k, v in get_available_voices(request).items()]} diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index f5a60c59e6..721a4069d4 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -90,14 +90,10 @@ router = APIRouter() log = logging.getLogger(__name__) -signin_rate_limiter = RateLimiter( - redis_client=get_redis_client(), limit=5 * 3, window=60 * 3 -) +signin_rate_limiter = RateLimiter(redis_client=get_redis_client(), limit=5 * 3, window=60 * 3) -def create_session_response( - request: Request, user, db, response: Response = None, set_cookie: bool = False -) -> dict: +def create_session_response(request: Request, user, db, response: Response = None, set_cookie: bool = False) -> dict: """ Create JWT token and build session response for a user. Shared helper for signin, signup, ldap_auth, add_user, and token_exchange endpoints. @@ -115,18 +111,14 @@ def create_session_response( expires_at = int(time.time()) + int(expires_delta.total_seconds()) token = create_token( - data={"id": user.id}, + data={'id': user.id}, expires_delta=expires_delta, ) if set_cookie and response: - datetime_expires_at = ( - datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) - if expires_at - else None - ) + datetime_expires_at = datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) if expires_at else None response.set_cookie( - key="token", + key='token', value=token, expires=datetime_expires_at, httponly=True, @@ -134,20 +126,18 @@ def create_session_response( secure=WEBUI_AUTH_COOKIE_SECURE, ) - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db - ) + user_permissions = get_permissions(user.id, request.app.state.config.USER_PERMISSIONS, db=db) return { - "token": token, - "token_type": "Bearer", - "expires_at": expires_at, - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": f"/api/v1/users/{user.id}/profile/image", - "permissions": user_permissions, + 'token': token, + 'token_type': 'Bearer', + 'expires_at': expires_at, + 'id': user.id, + 'email': user.email, + 'name': user.name, + 'role': user.role, + 'profile_image_url': f'/api/v1/users/{user.id}/profile/image', + 'permissions': user_permissions, } @@ -167,15 +157,14 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus): date_of_birth: Optional[datetime.date] = None -@router.get("/", response_model=SessionUserInfoResponse) +@router.get('/', response_model=SessionUserInfoResponse) async def get_session_user( request: Request, response: Response, user=Depends(get_current_user), db: Session = Depends(get_session), ): - - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') auth_token = get_http_authorization_cred(auth_header) token = auth_token.credentials data = decode_token(token) @@ -183,7 +172,7 @@ async def get_session_user( expires_at = None if data: - expires_at = data.get("exp") + expires_at = data.get('exp') if (expires_at is not None) and int(time.time()) > expires_at: raise HTTPException( @@ -193,38 +182,32 @@ async def get_session_user( # Set the cookie token response.set_cookie( - key="token", + key='token', value=token, - expires=( - datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) - if expires_at - else None - ), + expires=(datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) if expires_at else None), httponly=True, # Ensures the cookie is not accessible via JavaScript samesite=WEBUI_AUTH_COOKIE_SAME_SITE, secure=WEBUI_AUTH_COOKIE_SECURE, ) - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db - ) + user_permissions = get_permissions(user.id, request.app.state.config.USER_PERMISSIONS, db=db) return { - "token": token, - "token_type": "Bearer", - "expires_at": expires_at, - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - "bio": user.bio, - "gender": user.gender, - "date_of_birth": user.date_of_birth, - "status_emoji": user.status_emoji, - "status_message": user.status_message, - "status_expires_at": user.status_expires_at, - "permissions": user_permissions, + 'token': token, + 'token_type': 'Bearer', + 'expires_at': expires_at, + 'id': user.id, + 'email': user.email, + 'name': user.name, + 'role': user.role, + 'profile_image_url': user.profile_image_url, + 'bio': user.bio, + 'gender': user.gender, + 'date_of_birth': user.date_of_birth, + 'status_emoji': user.status_emoji, + 'status_message': user.status_message, + 'status_expires_at': user.status_expires_at, + 'permissions': user_permissions, } @@ -233,7 +216,7 @@ async def get_session_user( ############################ -@router.post("/update/profile", response_model=UserProfileImageResponse) +@router.post('/update/profile', response_model=UserProfileImageResponse) async def update_profile( form_data: UpdateProfileForm, session_user=Depends(get_verified_user), @@ -262,7 +245,7 @@ class UpdateTimezoneForm(BaseModel): timezone: str -@router.post("/update/timezone") +@router.post('/update/timezone') async def update_timezone( form_data: UpdateTimezoneForm, session_user=Depends(get_current_user), @@ -271,10 +254,10 @@ async def update_timezone( if session_user: Users.update_user_by_id( session_user.id, - {"timezone": form_data.timezone}, + {'timezone': form_data.timezone}, db=db, ) - return {"status": True} + return {'status': True} else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -284,7 +267,7 @@ async def update_timezone( ############################ -@router.post("/update/password", response_model=bool) +@router.post('/update/password', response_model=bool) async def update_password( form_data: UpdatePasswordForm, session_user=Depends(get_current_user), @@ -315,7 +298,7 @@ async def update_password( ############################ # LDAP Authentication ############################ -@router.post("/ldap", response_model=SessionUserResponse) +@router.post('/ldap', response_model=SessionUserResponse) async def ldap_auth( request: Request, response: Response, @@ -324,7 +307,7 @@ async def ldap_auth( ): # Security checks FIRST - before loading any config if not request.app.state.config.ENABLE_LDAP: - raise HTTPException(400, detail="LDAP authentication is not enabled") + raise HTTPException(400, detail='LDAP authentication is not enabled') if not ENABLE_PASSWORD_AUTH: raise HTTPException( @@ -344,14 +327,8 @@ async def ldap_auth( LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE - LDAP_VALIDATE_CERT = ( - CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE - ) - LDAP_CIPHERS = ( - request.app.state.config.LDAP_CIPHERS - if request.app.state.config.LDAP_CIPHERS - else "ALL" - ) + LDAP_VALIDATE_CERT = CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE + LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL' try: tls = Tls( @@ -361,8 +338,8 @@ async def ldap_auth( ciphers=LDAP_CIPHERS, ) except Exception as e: - log.error(f"TLS configuration error: {str(e)}") - raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.") + log.error(f'TLS configuration error: {str(e)}') + raise HTTPException(400, detail='Failed to configure TLS for LDAP connection.') try: server = Server( @@ -376,44 +353,38 @@ async def ldap_auth( server, LDAP_APP_DN, LDAP_APP_PASSWORD, - auto_bind="NONE", - authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS", + auto_bind='NONE', + authentication='SIMPLE' if LDAP_APP_DN else 'ANONYMOUS', ) if not await asyncio.to_thread(connection_app.bind): - raise HTTPException(400, detail="Application account bind failed") + raise HTTPException(400, detail='Application account bind failed') - ENABLE_LDAP_GROUP_MANAGEMENT = ( - request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT - ) + ENABLE_LDAP_GROUP_MANAGEMENT = request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT ENABLE_LDAP_GROUP_CREATION = request.app.state.config.ENABLE_LDAP_GROUP_CREATION LDAP_ATTRIBUTE_FOR_GROUPS = request.app.state.config.LDAP_ATTRIBUTE_FOR_GROUPS search_attributes = [ - f"{LDAP_ATTRIBUTE_FOR_USERNAME}", - f"{LDAP_ATTRIBUTE_FOR_MAIL}", - "cn", + f'{LDAP_ATTRIBUTE_FOR_USERNAME}', + f'{LDAP_ATTRIBUTE_FOR_MAIL}', + 'cn', ] if ENABLE_LDAP_GROUP_MANAGEMENT: - search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}") - log.info( - f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes" - ) - log.info(f"LDAP search attributes: {search_attributes}") + search_attributes.append(f'{LDAP_ATTRIBUTE_FOR_GROUPS}') + log.info(f'LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes') + log.info(f'LDAP search attributes: {search_attributes}') search_success = await asyncio.to_thread( connection_app.search, search_base=LDAP_SEARCH_BASE, - search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", + search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})', attributes=search_attributes, ) if not search_success or not connection_app.entries: - raise HTTPException(400, detail="User not found in the LDAP server") + raise HTTPException(400, detail='User not found in the LDAP server') entry = connection_app.entries[0] - entry_username = entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"].value - email = entry[ - f"{LDAP_ATTRIBUTE_FOR_MAIL}" - ].value # retrieve the Attribute value + entry_username = entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'].value + email = entry[f'{LDAP_ATTRIBUTE_FOR_MAIL}'].value # retrieve the Attribute value username_list = [] # list of usernames from LDAP attribute if isinstance(entry_username, list): @@ -423,7 +394,7 @@ async def ldap_auth( # TODO: support multiple emails if LDAP returns a list if not email: - raise HTTPException(400, "User does not have a valid email address.") + raise HTTPException(400, 'User does not have a valid email address.') elif isinstance(email, str): email = email.lower() elif isinstance(email, list): @@ -431,47 +402,43 @@ async def ldap_auth( else: email = str(email).lower() - cn = str(entry["cn"]) # common name + cn = str(entry['cn']) # common name user_dn = entry.entry_dn # user distinguished name user_groups = [] if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry: group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS] - log.info(f"LDAP raw group DNs for user {username_list}: {group_dns}") + log.info(f'LDAP raw group DNs for user {username_list}: {group_dns}') if group_dns: - log.info(f"LDAP group_dns original: {group_dns}") - log.info(f"LDAP group_dns type: {type(group_dns)}") - log.info(f"LDAP group_dns length: {len(group_dns)}") + log.info(f'LDAP group_dns original: {group_dns}') + log.info(f'LDAP group_dns type: {type(group_dns)}') + log.info(f'LDAP group_dns length: {len(group_dns)}') - if hasattr(group_dns, "value"): + if hasattr(group_dns, 'value'): group_dns = group_dns.value - log.info(f"Extracted .value property: {group_dns}") - elif hasattr(group_dns, "__iter__") and not isinstance( - group_dns, (str, bytes) - ): + log.info(f'Extracted .value property: {group_dns}') + elif hasattr(group_dns, '__iter__') and not isinstance(group_dns, (str, bytes)): group_dns = list(group_dns) - log.info(f"Converted to list: {group_dns}") + log.info(f'Converted to list: {group_dns}') if isinstance(group_dns, list): group_dns = [str(item) for item in group_dns] else: group_dns = [str(group_dns)] - log.info( - f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}" - ) + log.info(f'LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}') for group_idx, group_dn in enumerate(group_dns): group_dn = str(group_dn) - log.info(f"Processing group DN #{group_idx + 1}: {group_dn}") + log.info(f'Processing group DN #{group_idx + 1}: {group_dn}') try: group_cn = None - for item in group_dn.split(","): + for item in group_dn.split(','): item = item.strip() - if item.upper().startswith("CN="): + if item.upper().startswith('CN='): group_cn = item[3:] break @@ -479,22 +446,16 @@ async def ldap_auth( user_groups.append(group_cn) else: - log.warning( - f"Could not extract CN from group DN: {group_dn}" - ) + log.warning(f'Could not extract CN from group DN: {group_dn}') except Exception as e: - log.warning( - f"Failed to extract group name from DN {group_dn}: {e}" - ) + log.warning(f'Failed to extract group name from DN {group_dn}: {e}') - log.info( - f"LDAP groups for user {username_list}: {user_groups} (total: {len(user_groups)})" - ) + log.info(f'LDAP groups for user {username_list}: {user_groups} (total: {len(user_groups)})') else: - log.info(f"No groups found for user {username_list}") + log.info(f'No groups found for user {username_list}') elif ENABLE_LDAP_GROUP_MANAGEMENT: log.warning( - f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry" + f'LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry' ) if username_list and form_data.user.lower() in username_list: @@ -502,20 +463,16 @@ async def ldap_auth( server, user_dn, form_data.password, - auto_bind="NONE", - authentication="SIMPLE", + auto_bind='NONE', + authentication='SIMPLE', ) if not await asyncio.to_thread(connection_user.bind): - raise HTTPException(400, "Authentication failed.") + raise HTTPException(400, 'Authentication failed.') user = Users.get_user_by_email(email, db=db) if not user: try: - role = ( - "admin" - if not Users.has_users(db=db) - else request.app.state.config.DEFAULT_USER_ROLE - ) + role = 'admin' if not Users.has_users(db=db) else request.app.state.config.DEFAULT_USER_ROLE user = Auths.insert_new_auth( email=email, @@ -526,9 +483,7 @@ async def ldap_auth( ) if not user: - raise HTTPException( - 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR - ) + raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, @@ -539,39 +494,29 @@ async def ldap_auth( except HTTPException: raise except Exception as err: - log.error(f"LDAP user creation error: {str(err)}") - raise HTTPException( - 500, detail="Internal error occurred during LDAP user creation." - ) + log.error(f'LDAP user creation error: {str(err)}') + raise HTTPException(500, detail='Internal error occurred during LDAP user creation.') user = Auths.authenticate_user_by_email(email, db=db) if user: - if ( - user.role != "admin" - and ENABLE_LDAP_GROUP_MANAGEMENT - and user_groups - ): + if user.role != 'admin' and ENABLE_LDAP_GROUP_MANAGEMENT and user_groups: if ENABLE_LDAP_GROUP_CREATION: Groups.create_groups_by_group_names(user.id, user_groups, db=db) try: Groups.sync_groups_by_group_names(user.id, user_groups, db=db) - log.info( - f"Successfully synced groups for user {user.id}: {user_groups}" - ) + log.info(f'Successfully synced groups for user {user.id}: {user_groups}') except Exception as e: - log.error(f"Failed to sync groups for user {user.id}: {e}") + log.error(f'Failed to sync groups for user {user.id}: {e}') - return create_session_response( - request, user, db, response, set_cookie=True - ) + return create_session_response(request, user, db, response, set_cookie=True) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - raise HTTPException(400, "User record mismatch.") + raise HTTPException(400, 'User record mismatch.') except Exception as e: - log.error(f"LDAP authentication error: {str(e)}") - raise HTTPException(400, detail="LDAP authentication failed.") + log.error(f'LDAP authentication error: {str(e)}') + raise HTTPException(400, detail='LDAP authentication failed.') ############################ @@ -579,7 +524,7 @@ async def ldap_auth( ############################ -@router.post("/signin", response_model=SessionUserResponse) +@router.post('/signin', response_model=SessionUserResponse) async def signin( request: Request, response: Response, @@ -602,7 +547,7 @@ async def signin( if WEBUI_AUTH_TRUSTED_NAME_HEADER: name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) try: - name = urllib.parse.unquote(name, encoding="utf-8") + name = urllib.parse.unquote(name, encoding='utf-8') except Exception as e: pass @@ -616,18 +561,16 @@ async def signin( ) user = Auths.authenticate_user_by_email(email, db=db) - if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": - group_names = request.headers.get( - WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" - ).split(",") + if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != 'admin': + group_names = request.headers.get(WEBUI_AUTH_TRUSTED_GROUPS_HEADER, '').split(',') group_names = [name.strip() for name in group_names if name.strip()] if group_names: Groups.sync_groups_by_group_names(user.id, group_names, db=db) elif WEBUI_AUTH == False: - admin_email = "admin@localhost" - admin_password = "admin" + admin_email = 'admin@localhost' + admin_password = 'admin' if Users.get_user_by_email(admin_email.lower(), db=db): user = Auths.authenticate_user( @@ -643,7 +586,7 @@ async def signin( request, admin_email, admin_password, - "User", + 'User', db=db, ) @@ -659,14 +602,14 @@ async def signin( detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED, ) - password_bytes = form_data.password.encode("utf-8") + password_bytes = form_data.password.encode('utf-8') if len(password_bytes) > 72: # TODO: Implement other hashing algorithms that support longer passwords - log.info("Password too long, truncating to 72 bytes for bcrypt") + log.info('Password too long, truncating to 72 bytes for bcrypt') password_bytes = password_bytes[:72] # decode safely — ignore incomplete UTF-8 sequences - form_data.password = password_bytes.decode("utf-8", errors="ignore") + form_data.password = password_bytes.decode('utf-8', errors='ignore') user = Auths.authenticate_user( form_data.email.lower(), @@ -690,7 +633,7 @@ async def signup_handler( email: str, password: str, name: str, - profile_image_url: str = "/user.png", + profile_image_url: str = '/user.png', *, db: Session, ) -> UserModel: @@ -720,7 +663,7 @@ async def signup_handler( # Atomically check if this is the only user *after* the insert. # Only the single user present at this point should become admin. if Users.get_num_users(db=db) == 1: - Users.update_user_role_by_id(user.id, "admin", db=db) + Users.update_user_role_by_id(user.id, 'admin', db=db) user = Users.get_user_by_id(user.id, db=db) request.app.state.config.ENABLE_SIGNUP = False @@ -730,9 +673,9 @@ async def signup_handler( request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), + 'action': 'signup', + 'message': WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + 'user': user.model_dump_json(exclude_none=True), }, ) @@ -745,7 +688,7 @@ async def signup_handler( return user -@router.post("/signup", response_model=SessionUserResponse) +@router.post('/signup', response_model=SessionUserResponse) async def signup( request: Request, response: Response, @@ -755,24 +698,15 @@ async def signup( has_users = Users.has_users(db=db) if WEBUI_AUTH: - if ( - not request.app.state.config.ENABLE_SIGNUP - or not request.app.state.config.ENABLE_LOGIN_FORM - ): + if not request.app.state.config.ENABLE_SIGNUP or not request.app.state.config.ENABLE_LOGIN_FORM: if has_users or not ENABLE_INITIAL_ADMIN_SIGNUP: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) + raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) else: if has_users: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) + raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) if not validate_email_format(form_data.email.lower()): - raise HTTPException( - status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT - ) + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -795,34 +729,31 @@ async def signup( except HTTPException: raise except Exception as err: - log.error(f"Signup error: {str(err)}") - raise HTTPException(500, detail="An internal error occurred during signup.") + log.error(f'Signup error: {str(err)}') + raise HTTPException(500, detail='An internal error occurred during signup.') -@router.get("/signout") -async def signout( - request: Request, response: Response, db: Session = Depends(get_session) -): - +@router.get('/signout') +async def signout(request: Request, response: Response, db: Session = Depends(get_session)): # get auth token from headers or cookies token = None - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') if auth_header: auth_cred = get_http_authorization_cred(auth_header) token = auth_cred.credentials else: - token = request.cookies.get("token") + token = request.cookies.get('token') if token: await invalidate_token(request, token) - response.delete_cookie("token") - response.delete_cookie("oui-session") - response.delete_cookie("oauth_id_token") + response.delete_cookie('token') + response.delete_cookie('oui-session') + response.delete_cookie('oauth_id_token') - oauth_session_id = request.cookies.get("oauth_session_id") + oauth_session_id = request.cookies.get('oauth_session_id') if oauth_session_id: - response.delete_cookie("oauth_session_id") + response.delete_cookie('oauth_session_id') session = OAuthSessions.get_session_by_id(oauth_session_id, db=db) @@ -832,49 +763,47 @@ async def signout( return JSONResponse( status_code=200, content={ - "status": True, - "redirect_url": OPENID_END_SESSION_ENDPOINT.value, + 'status': True, + 'redirect_url': OPENID_END_SESSION_ENDPOINT.value, }, headers=response.headers, ) oauth_server_metadata_url = ( - request.app.state.oauth_manager.get_server_metadata_url(session.provider) - if session - else None + request.app.state.oauth_manager.get_server_metadata_url(session.provider) if session else None ) or OPENID_PROVIDER_URL.value if session and oauth_server_metadata_url: - oauth_id_token = session.token.get("id_token") + oauth_id_token = session.token.get('id_token') try: async with ClientSession(trust_env=True) as session: async with session.get(oauth_server_metadata_url) as r: if r.status == 200: openid_data = await r.json() - logout_url = openid_data.get("end_session_endpoint") + logout_url = openid_data.get('end_session_endpoint') if logout_url: return JSONResponse( status_code=200, content={ - "status": True, - "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}" + 'status': True, + 'redirect_url': f'{logout_url}?id_token_hint={oauth_id_token}' + ( - f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}" + f'&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}' if WEBUI_AUTH_SIGNOUT_REDIRECT_URL - else "" + else '' ), }, headers=response.headers, ) else: - raise Exception("Failed to fetch OpenID configuration") + raise Exception('Failed to fetch OpenID configuration') except Exception as e: - log.error(f"OpenID signout error: {str(e)}") + log.error(f'OpenID signout error: {str(e)}') raise HTTPException( status_code=500, - detail="Failed to sign out from the OpenID provider.", + detail='Failed to sign out from the OpenID provider.', headers=response.headers, ) @@ -882,15 +811,13 @@ async def signout( return JSONResponse( status_code=200, content={ - "status": True, - "redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + 'status': True, + 'redirect_url': WEBUI_AUTH_SIGNOUT_REDIRECT_URL, }, headers=response.headers, ) - return JSONResponse( - status_code=200, content={"status": True}, headers=response.headers - ) + return JSONResponse(status_code=200, content={'status': True}, headers=response.headers) ############################ @@ -898,7 +825,7 @@ async def signout( ############################ -@router.post("/add", response_model=SigninResponse) +@router.post('/add', response_model=SigninResponse) async def add_user( request: Request, form_data: AddUserForm, @@ -906,9 +833,7 @@ async def add_user( db: Session = Depends(get_session), ): if not validate_email_format(form_data.email.lower()): - raise HTTPException( - status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT - ) + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -936,25 +861,23 @@ async def add_user( db=db, ) - token = create_token(data={"id": user.id}) + token = create_token(data={'id': user.id}) return { - "token": token, - "token_type": "Bearer", - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": f"/api/v1/users/{user.id}/profile/image", + 'token': token, + 'token_type': 'Bearer', + 'id': user.id, + 'email': user.email, + 'name': user.name, + 'role': user.role, + 'profile_image_url': f'/api/v1/users/{user.id}/profile/image', } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) except HTTPException: raise except Exception as err: - log.error(f"Add user error: {str(err)}") - raise HTTPException( - 500, detail="An internal error occurred while adding the user." - ) + log.error(f'Add user error: {str(err)}') + raise HTTPException(500, detail='An internal error occurred while adding the user.') ############################ @@ -962,15 +885,13 @@ async def add_user( ############################ -@router.get("/admin/details") -async def get_admin_details( - request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) -): +@router.get('/admin/details') +async def get_admin_details(request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None - log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") + log.info(f'Admin details - Email: {admin_email}, Name: {admin_name}') if admin_email: admin = Users.get_user_by_email(admin_email, db=db) @@ -983,8 +904,8 @@ async def get_admin_details( admin_name = admin.name return { - "name": admin_name, - "email": admin_email, + 'name': admin_name, + 'email': admin_email, } else: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) @@ -995,31 +916,31 @@ async def get_admin_details( ############################ -@router.get("/admin/config") +@router.get('/admin/config') async def get_admin_config(request: Request, user=Depends(get_admin_user)): return { - "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, - "ADMIN_EMAIL": request.app.state.config.ADMIN_EMAIL, - "WEBUI_URL": request.app.state.config.WEBUI_URL, - "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, - "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS, - "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, - "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, - "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, - "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID, - "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, - "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, - "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, - "ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS, - "FOLDER_MAX_FILE_COUNT": request.app.state.config.FOLDER_MAX_FILE_COUNT, - "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS, - "ENABLE_MEMORIES": request.app.state.config.ENABLE_MEMORIES, - "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES, - "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS, - "ENABLE_USER_STATUS": request.app.state.config.ENABLE_USER_STATUS, - "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE, - "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT, - "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK, + 'SHOW_ADMIN_DETAILS': request.app.state.config.SHOW_ADMIN_DETAILS, + 'ADMIN_EMAIL': request.app.state.config.ADMIN_EMAIL, + 'WEBUI_URL': request.app.state.config.WEBUI_URL, + 'ENABLE_SIGNUP': request.app.state.config.ENABLE_SIGNUP, + 'ENABLE_API_KEYS': request.app.state.config.ENABLE_API_KEYS, + 'ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS': request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, + 'API_KEYS_ALLOWED_ENDPOINTS': request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, + 'DEFAULT_USER_ROLE': request.app.state.config.DEFAULT_USER_ROLE, + 'DEFAULT_GROUP_ID': request.app.state.config.DEFAULT_GROUP_ID, + 'JWT_EXPIRES_IN': request.app.state.config.JWT_EXPIRES_IN, + 'ENABLE_COMMUNITY_SHARING': request.app.state.config.ENABLE_COMMUNITY_SHARING, + 'ENABLE_MESSAGE_RATING': request.app.state.config.ENABLE_MESSAGE_RATING, + 'ENABLE_FOLDERS': request.app.state.config.ENABLE_FOLDERS, + 'FOLDER_MAX_FILE_COUNT': request.app.state.config.FOLDER_MAX_FILE_COUNT, + 'ENABLE_CHANNELS': request.app.state.config.ENABLE_CHANNELS, + 'ENABLE_MEMORIES': request.app.state.config.ENABLE_MEMORIES, + 'ENABLE_NOTES': request.app.state.config.ENABLE_NOTES, + 'ENABLE_USER_WEBHOOKS': request.app.state.config.ENABLE_USER_WEBHOOKS, + 'ENABLE_USER_STATUS': request.app.state.config.ENABLE_USER_STATUS, + 'PENDING_USER_OVERLAY_TITLE': request.app.state.config.PENDING_USER_OVERLAY_TITLE, + 'PENDING_USER_OVERLAY_CONTENT': request.app.state.config.PENDING_USER_OVERLAY_CONTENT, + 'RESPONSE_WATERMARK': request.app.state.config.RESPONSE_WATERMARK, } @@ -1048,82 +969,70 @@ class AdminConfig(BaseModel): RESPONSE_WATERMARK: Optional[str] = None -@router.post("/admin/config") -async def update_admin_config( - request: Request, form_data: AdminConfig, user=Depends(get_admin_user) -): +@router.post('/admin/config') +async def update_admin_config(request: Request, form_data: AdminConfig, user=Depends(get_admin_user)): request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS request.app.state.config.ADMIN_EMAIL = form_data.ADMIN_EMAIL request.app.state.config.WEBUI_URL = form_data.WEBUI_URL request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP request.app.state.config.ENABLE_API_KEYS = form_data.ENABLE_API_KEYS - request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = ( - form_data.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS - ) - request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS = ( - form_data.API_KEYS_ALLOWED_ENDPOINTS - ) + request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = form_data.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS + request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS = form_data.API_KEYS_ALLOWED_ENDPOINTS request.app.state.config.ENABLE_FOLDERS = form_data.ENABLE_FOLDERS request.app.state.config.FOLDER_MAX_FILE_COUNT = ( - int(form_data.FOLDER_MAX_FILE_COUNT) if form_data.FOLDER_MAX_FILE_COUNT else "" + int(form_data.FOLDER_MAX_FILE_COUNT) if form_data.FOLDER_MAX_FILE_COUNT else '' ) request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS request.app.state.config.ENABLE_MEMORIES = form_data.ENABLE_MEMORIES request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES - if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: + if form_data.DEFAULT_USER_ROLE in ['pending', 'user', 'admin']: request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE request.app.state.config.DEFAULT_GROUP_ID = form_data.DEFAULT_GROUP_ID - pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$" + pattern = r'^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$' # Check if the input string matches the pattern if re.match(pattern, form_data.JWT_EXPIRES_IN): request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN - request.app.state.config.ENABLE_COMMUNITY_SHARING = ( - form_data.ENABLE_COMMUNITY_SHARING - ) + request.app.state.config.ENABLE_COMMUNITY_SHARING = form_data.ENABLE_COMMUNITY_SHARING request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS request.app.state.config.ENABLE_USER_STATUS = form_data.ENABLE_USER_STATUS - request.app.state.config.PENDING_USER_OVERLAY_TITLE = ( - form_data.PENDING_USER_OVERLAY_TITLE - ) - request.app.state.config.PENDING_USER_OVERLAY_CONTENT = ( - form_data.PENDING_USER_OVERLAY_CONTENT - ) + request.app.state.config.PENDING_USER_OVERLAY_TITLE = form_data.PENDING_USER_OVERLAY_TITLE + request.app.state.config.PENDING_USER_OVERLAY_CONTENT = form_data.PENDING_USER_OVERLAY_CONTENT request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK return { - "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, - "ADMIN_EMAIL": request.app.state.config.ADMIN_EMAIL, - "WEBUI_URL": request.app.state.config.WEBUI_URL, - "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, - "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS, - "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, - "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, - "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, - "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID, - "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, - "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, - "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, - "ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS, - "FOLDER_MAX_FILE_COUNT": request.app.state.config.FOLDER_MAX_FILE_COUNT, - "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS, - "ENABLE_MEMORIES": request.app.state.config.ENABLE_MEMORIES, - "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES, - "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS, - "ENABLE_USER_STATUS": request.app.state.config.ENABLE_USER_STATUS, - "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE, - "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT, - "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK, + 'SHOW_ADMIN_DETAILS': request.app.state.config.SHOW_ADMIN_DETAILS, + 'ADMIN_EMAIL': request.app.state.config.ADMIN_EMAIL, + 'WEBUI_URL': request.app.state.config.WEBUI_URL, + 'ENABLE_SIGNUP': request.app.state.config.ENABLE_SIGNUP, + 'ENABLE_API_KEYS': request.app.state.config.ENABLE_API_KEYS, + 'ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS': request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, + 'API_KEYS_ALLOWED_ENDPOINTS': request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, + 'DEFAULT_USER_ROLE': request.app.state.config.DEFAULT_USER_ROLE, + 'DEFAULT_GROUP_ID': request.app.state.config.DEFAULT_GROUP_ID, + 'JWT_EXPIRES_IN': request.app.state.config.JWT_EXPIRES_IN, + 'ENABLE_COMMUNITY_SHARING': request.app.state.config.ENABLE_COMMUNITY_SHARING, + 'ENABLE_MESSAGE_RATING': request.app.state.config.ENABLE_MESSAGE_RATING, + 'ENABLE_FOLDERS': request.app.state.config.ENABLE_FOLDERS, + 'FOLDER_MAX_FILE_COUNT': request.app.state.config.FOLDER_MAX_FILE_COUNT, + 'ENABLE_CHANNELS': request.app.state.config.ENABLE_CHANNELS, + 'ENABLE_MEMORIES': request.app.state.config.ENABLE_MEMORIES, + 'ENABLE_NOTES': request.app.state.config.ENABLE_NOTES, + 'ENABLE_USER_WEBHOOKS': request.app.state.config.ENABLE_USER_WEBHOOKS, + 'ENABLE_USER_STATUS': request.app.state.config.ENABLE_USER_STATUS, + 'PENDING_USER_OVERLAY_TITLE': request.app.state.config.PENDING_USER_OVERLAY_TITLE, + 'PENDING_USER_OVERLAY_CONTENT': request.app.state.config.PENDING_USER_OVERLAY_CONTENT, + 'RESPONSE_WATERMARK': request.app.state.config.RESPONSE_WATERMARK, } @@ -1131,62 +1040,58 @@ class LdapServerConfig(BaseModel): label: str host: str port: Optional[int] = None - attribute_for_mail: str = "mail" - attribute_for_username: str = "uid" + attribute_for_mail: str = 'mail' + attribute_for_username: str = 'uid' app_dn: str app_dn_password: str search_base: str - search_filters: str = "" + search_filters: str = '' use_tls: bool = True certificate_path: Optional[str] = None validate_cert: bool = True - ciphers: Optional[str] = "ALL" + ciphers: Optional[str] = 'ALL' -@router.get("/admin/config/ldap/server", response_model=LdapServerConfig) +@router.get('/admin/config/ldap/server', response_model=LdapServerConfig) async def get_ldap_server(request: Request, user=Depends(get_admin_user)): return { - "label": request.app.state.config.LDAP_SERVER_LABEL, - "host": request.app.state.config.LDAP_SERVER_HOST, - "port": request.app.state.config.LDAP_SERVER_PORT, - "attribute_for_mail": request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, - "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, - "app_dn": request.app.state.config.LDAP_APP_DN, - "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, - "search_base": request.app.state.config.LDAP_SEARCH_BASE, - "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, - "use_tls": request.app.state.config.LDAP_USE_TLS, - "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "validate_cert": request.app.state.config.LDAP_VALIDATE_CERT, - "ciphers": request.app.state.config.LDAP_CIPHERS, + 'label': request.app.state.config.LDAP_SERVER_LABEL, + 'host': request.app.state.config.LDAP_SERVER_HOST, + 'port': request.app.state.config.LDAP_SERVER_PORT, + 'attribute_for_mail': request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, + 'attribute_for_username': request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + 'app_dn': request.app.state.config.LDAP_APP_DN, + 'app_dn_password': request.app.state.config.LDAP_APP_PASSWORD, + 'search_base': request.app.state.config.LDAP_SEARCH_BASE, + 'search_filters': request.app.state.config.LDAP_SEARCH_FILTERS, + 'use_tls': request.app.state.config.LDAP_USE_TLS, + 'certificate_path': request.app.state.config.LDAP_CA_CERT_FILE, + 'validate_cert': request.app.state.config.LDAP_VALIDATE_CERT, + 'ciphers': request.app.state.config.LDAP_CIPHERS, } -@router.post("/admin/config/ldap/server") -async def update_ldap_server( - request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) -): +@router.post('/admin/config/ldap/server') +async def update_ldap_server(request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)): required_fields = [ - "label", - "host", - "attribute_for_mail", - "attribute_for_username", - "search_base", + 'label', + 'host', + 'attribute_for_mail', + 'attribute_for_username', + 'search_base', ] for key in required_fields: value = getattr(form_data, key) if not value: - raise HTTPException(400, detail=f"Required field {key} is empty") + raise HTTPException(400, detail=f'Required field {key} is empty') request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_PORT = form_data.port request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL = form_data.attribute_for_mail - request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( - form_data.attribute_for_username - ) - request.app.state.config.LDAP_APP_DN = form_data.app_dn or "" - request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password or "" + request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username + request.app.state.config.LDAP_APP_DN = form_data.app_dn or '' + request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password or '' request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters request.app.state.config.LDAP_USE_TLS = form_data.use_tls @@ -1195,37 +1100,35 @@ async def update_ldap_server( request.app.state.config.LDAP_CIPHERS = form_data.ciphers return { - "label": request.app.state.config.LDAP_SERVER_LABEL, - "host": request.app.state.config.LDAP_SERVER_HOST, - "port": request.app.state.config.LDAP_SERVER_PORT, - "attribute_for_mail": request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, - "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, - "app_dn": request.app.state.config.LDAP_APP_DN, - "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, - "search_base": request.app.state.config.LDAP_SEARCH_BASE, - "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, - "use_tls": request.app.state.config.LDAP_USE_TLS, - "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "validate_cert": request.app.state.config.LDAP_VALIDATE_CERT, - "ciphers": request.app.state.config.LDAP_CIPHERS, + 'label': request.app.state.config.LDAP_SERVER_LABEL, + 'host': request.app.state.config.LDAP_SERVER_HOST, + 'port': request.app.state.config.LDAP_SERVER_PORT, + 'attribute_for_mail': request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, + 'attribute_for_username': request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + 'app_dn': request.app.state.config.LDAP_APP_DN, + 'app_dn_password': request.app.state.config.LDAP_APP_PASSWORD, + 'search_base': request.app.state.config.LDAP_SEARCH_BASE, + 'search_filters': request.app.state.config.LDAP_SEARCH_FILTERS, + 'use_tls': request.app.state.config.LDAP_USE_TLS, + 'certificate_path': request.app.state.config.LDAP_CA_CERT_FILE, + 'validate_cert': request.app.state.config.LDAP_VALIDATE_CERT, + 'ciphers': request.app.state.config.LDAP_CIPHERS, } -@router.get("/admin/config/ldap") +@router.get('/admin/config/ldap') async def get_ldap_config(request: Request, user=Depends(get_admin_user)): - return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + return {'ENABLE_LDAP': request.app.state.config.ENABLE_LDAP} class LdapConfigForm(BaseModel): enable_ldap: Optional[bool] = None -@router.post("/admin/config/ldap") -async def update_ldap_config( - request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user) -): +@router.post('/admin/config/ldap') +async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)): request.app.state.config.ENABLE_LDAP = form_data.enable_ldap - return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + return {'ENABLE_LDAP': request.app.state.config.ENABLE_LDAP} ############################ @@ -1234,12 +1137,10 @@ async def update_ldap_config( # create api key -@router.post("/api_key", response_model=ApiKey) -async def generate_api_key( - request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) -): +@router.post('/api_key', response_model=ApiKey) +async def generate_api_key(request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)): if not request.app.state.config.ENABLE_API_KEYS or not has_permission( - user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS + user.id, 'features.api_keys', request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -1251,29 +1152,25 @@ async def generate_api_key( if success: return { - "api_key": api_key, + 'api_key': api_key, } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR) # delete api key -@router.delete("/api_key", response_model=bool) -async def delete_api_key( - user=Depends(get_current_user), db: Session = Depends(get_session) -): +@router.delete('/api_key', response_model=bool) +async def delete_api_key(user=Depends(get_current_user), db: Session = Depends(get_session)): return Users.delete_user_api_key_by_id(user.id, db=db) # get api key -@router.get("/api_key", response_model=ApiKey) -async def get_api_key( - user=Depends(get_current_user), db: Session = Depends(get_session) -): +@router.get('/api_key', response_model=ApiKey) +async def get_api_key(user=Depends(get_current_user), db: Session = Depends(get_session)): api_key = Users.get_user_api_key_by_id(user.id, db=db) if api_key: return { - "api_key": api_key, + 'api_key': api_key, } else: raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) @@ -1288,7 +1185,7 @@ class TokenExchangeForm(BaseModel): token: str # OAuth access token from external provider -@router.post("/oauth/{provider}/token/exchange", response_model=SessionUserResponse) +@router.post('/oauth/{provider}/token/exchange', response_model=SessionUserResponse) async def token_exchange( request: Request, response: Response, @@ -1303,7 +1200,7 @@ async def token_exchange( if not ENABLE_OAUTH_TOKEN_EXCHANGE: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Token exchange is disabled", + detail='Token exchange is disabled', ) provider = provider.lower() @@ -1325,19 +1222,19 @@ async def token_exchange( # Validate the token by calling the userinfo endpoint try: - token_data = {"access_token": form_data.token, "token_type": "Bearer"} + token_data = {'access_token': form_data.token, 'token_type': 'Bearer'} user_data = await client.userinfo(token=token_data) if not user_data: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid token or unable to fetch user info", + detail='Invalid token or unable to fetch user info', ) except Exception as e: - log.warning(f"Token exchange failed for provider {provider}: {e}") + log.warning(f'Token exchange failed for provider {provider}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid token or unable to validate with provider", + detail='Invalid token or unable to validate with provider', ) # Extract user information from the token claims @@ -1345,23 +1242,20 @@ async def token_exchange( username_claim = request.app.state.config.OAUTH_USERNAME_CLAIM # Get sub claim - sub = user_data.get( - request.app.state.config.OAUTH_SUB_CLAIM - or OAUTH_PROVIDERS[provider].get("sub_claim", "sub") - ) + sub = user_data.get(request.app.state.config.OAUTH_SUB_CLAIM or OAUTH_PROVIDERS[provider].get('sub_claim', 'sub')) if not sub: - log.warning(f"Token exchange failed: sub claim missing from user data") + log.warning(f'Token exchange failed: sub claim missing from user data') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Token missing required 'sub' claim", ) - email = user_data.get(email_claim, "") + email = user_data.get(email_claim, '') if not email: - log.warning(f"Token exchange failed: email claim missing from user data") + log.warning(f'Token exchange failed: email claim missing from user data') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Token missing required email claim", + detail='Token missing required email claim', ) email = email.lower() @@ -1378,7 +1272,7 @@ async def token_exchange( if not user: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="User not found. Please sign in via the web interface first.", + detail='User not found. Please sign in via the web interface first.', ) return create_session_response(request, user, db) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 55a6e6ebba..06d2d2a6bb 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -75,34 +75,28 @@ router = APIRouter() def channel_has_access( user_id: str, channel: ChannelModel, - permission: str = "read", + permission: str = 'read', strict: bool = True, db: Optional[Session] = None, ) -> bool: if AccessGrants.has_access( user_id=user_id, - resource_type="channel", + resource_type='channel', resource_id=channel.id, permission=permission, db=db, ): return True - if ( - not strict - and permission == "write" - and has_public_read_access_grant(channel.access_grants) - ): + if not strict and permission == 'write' and has_public_read_access_grant(channel.access_grants): return True return False -def get_channel_users_with_access( - channel: ChannelModel, permission: str = "read", db: Optional[Session] = None -): +def get_channel_users_with_access(channel: ChannelModel, permission: str = 'read', db: Optional[Session] = None): return AccessGrants.get_users_with_access( - resource_type="channel", + resource_type='channel', resource_id=channel.id, permission=permission, db=db, @@ -110,9 +104,9 @@ def get_channel_users_with_access( def get_channel_permitted_group_and_user_ids( - channel: ChannelModel, permission: str = "read" + channel: ChannelModel, permission: str = 'read' ) -> Optional[dict[str, list[str]]]: - if permission == "read" and has_public_read_access_grant(channel.access_grants): + if permission == 'read' and has_public_read_access_grant(channel.access_grants): return None user_ids = [] @@ -121,14 +115,14 @@ def get_channel_permitted_group_and_user_ids( for grant in channel.access_grants: if grant.permission != permission: continue - if grant.principal_type == "group": + if grant.principal_type == 'group': group_ids.append(grant.principal_id) - elif grant.principal_type == "user" and grant.principal_id != "*": + elif grant.principal_type == 'user' and grant.principal_id != '*': user_ids.append(grant.principal_id) return { - "user_ids": list(dict.fromkeys(user_ids)), - "group_ids": list(dict.fromkeys(group_ids)), + 'user_ids': list(dict.fromkeys(user_ids)), + 'group_ids': list(dict.fromkeys(group_ids)), } @@ -142,12 +136,12 @@ def check_channels_access(request: Request, user: Optional[UserModel] = None): if not request.app.state.config.ENABLE_CHANNELS: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Channels are not enabled", + detail='Channels are not enabled', ) if user: - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + if user.role != 'admin' and not has_permission( + user.id, 'features.channels', request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -168,7 +162,7 @@ class ChannelListItemResponse(ChannelModel): unread_count: int = 0 -@router.get("/", response_model=list[ChannelListItemResponse]) +@router.get('/', response_model=list[ChannelListItemResponse]) async def get_channels( request: Request, user=Depends(get_verified_user), @@ -182,29 +176,22 @@ async def get_channels( last_message = Messages.get_last_message_by_channel_id(channel.id, db=db) last_message_at = last_message.created_at if last_message else None - channel_member = Channels.get_member_by_channel_and_user_id( - channel.id, user.id, db=db - ) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = ( - Messages.get_unread_message_count( - channel.id, user.id, channel_member.last_read_at, db=db - ) + Messages.get_unread_message_count(channel.id, user.id, channel_member.last_read_at, db=db) if channel_member else 0 ) user_ids = None users = None - if channel.type == "dm": - user_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id(channel.id, db=db) - ] + if channel.type == 'dm': + user_ids = [member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)] users = [ UserIdNameStatusResponse( **{ **user.model_dump(), - "is_active": Users.is_active(user), + 'is_active': Users.is_active(user), } ) for user in Users.get_users_by_user_ids(user_ids, db=db) @@ -223,14 +210,14 @@ async def get_channels( return channel_list -@router.get("/list", response_model=list[ChannelModel]) +@router.get('/list', response_model=list[ChannelModel]) async def get_all_channels( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): check_channels_access(request) - if user.role == "admin": + if user.role == 'admin': return Channels.get_channels(db=db) return Channels.get_channels_by_user_id(user.id, db=db) @@ -240,7 +227,7 @@ async def get_all_channels( ############################ -@router.get("/users/{user_id}", response_model=Optional[ChannelModel]) +@router.get('/users/{user_id}', response_model=Optional[ChannelModel]) async def get_dm_channel_by_user_id( request: Request, user_id: str, @@ -249,35 +236,26 @@ async def get_dm_channel_by_user_id( ): check_channels_access(request, user) try: - existing_channel = Channels.get_dm_channel_by_user_ids( - [user.id, user_id], db=db - ) + existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id], db=db) if existing_channel: participant_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id( - existing_channel.id, db=db - ) + member.user_id for member in Channels.get_members_by_channel_id(existing_channel.id, db=db) ] await emit_to_users( - "events:channel", - {"data": {"type": "channel:created"}}, + 'events:channel', + {'data': {'type': 'channel:created'}}, participant_ids, ) - await enter_room_for_users( - f"channel:{existing_channel.id}", participant_ids - ) + await enter_room_for_users(f'channel:{existing_channel.id}', participant_ids) - Channels.update_member_active_status( - existing_channel.id, user.id, True, db=db - ) + Channels.update_member_active_status(existing_channel.id, user.id, True, db=db) return ChannelModel(**existing_channel.model_dump()) channel = Channels.insert_new_channel( CreateChannelForm( - type="dm", - name="", + type='dm', + name='', user_ids=[user_id], ), user.id, @@ -285,26 +263,21 @@ async def get_dm_channel_by_user_id( ) if channel: - participant_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id(channel.id, db=db) - ] + participant_ids = [member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)] await emit_to_users( - "events:channel", - {"data": {"type": "channel:created"}}, + 'events:channel', + {'data': {'type': 'channel:created'}}, participant_ids, ) - await enter_room_for_users(f"channel:{channel.id}", participant_ids) + await enter_room_for_users(f'channel:{channel.id}', participant_ids) return ChannelModel(**channel.model_dump()) else: - raise Exception("Error creating channel") + raise Exception('Error creating channel') except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -312,7 +285,7 @@ async def get_dm_channel_by_user_id( ############################ -@router.post("/create", response_model=Optional[ChannelModel]) +@router.post('/create', response_model=Optional[ChannelModel]) async def create_new_channel( request: Request, form_data: CreateChannelForm, @@ -321,7 +294,7 @@ async def create_new_channel( ): check_channels_access(request, user) - if form_data.type not in ["group", "dm"] and user.role != "admin": + if form_data.type not in ['group', 'dm'] and user.role != 'admin': # Only admins can create standard channels (joined by default) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -329,54 +302,40 @@ async def create_new_channel( ) try: - if form_data.type == "dm": - existing_channel = Channels.get_dm_channel_by_user_ids( - [user.id, *form_data.user_ids], db=db - ) + if form_data.type == 'dm': + existing_channel = Channels.get_dm_channel_by_user_ids([user.id, *form_data.user_ids], db=db) if existing_channel: participant_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id( - existing_channel.id, db=db - ) + member.user_id for member in Channels.get_members_by_channel_id(existing_channel.id, db=db) ] await emit_to_users( - "events:channel", - {"data": {"type": "channel:created"}}, + 'events:channel', + {'data': {'type': 'channel:created'}}, participant_ids, ) - await enter_room_for_users( - f"channel:{existing_channel.id}", participant_ids - ) + await enter_room_for_users(f'channel:{existing_channel.id}', participant_ids) - Channels.update_member_active_status( - existing_channel.id, user.id, True, db=db - ) + Channels.update_member_active_status(existing_channel.id, user.id, True, db=db) return ChannelModel(**existing_channel.model_dump()) channel = Channels.insert_new_channel(form_data, user.id, db=db) if channel: - participant_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id(channel.id, db=db) - ] + participant_ids = [member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)] await emit_to_users( - "events:channel", - {"data": {"type": "channel:created"}}, + 'events:channel', + {'data': {'type': 'channel:created'}}, participant_ids, ) - await enter_room_for_users(f"channel:{channel.id}", participant_ids) + await enter_room_for_users(f'channel:{channel.id}', participant_ids) return ChannelModel(**channel.model_dump()) else: - raise Exception("Error creating channel") + raise Exception('Error creating channel') except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -392,7 +351,7 @@ class ChannelFullResponse(ChannelResponse): unread_count: int = 0 -@router.get("/{id}", response_model=Optional[ChannelFullResponse]) +@router.get('/{id}', response_model=Optional[ChannelFullResponse]) async def get_channel_by_id( request: Request, id: str, @@ -402,37 +361,28 @@ async def get_channel_by_id( check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) user_ids = None users = None - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) - user_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id(channel.id, db=db) - ] + user_ids = [member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)] users = [ UserIdNameStatusResponse( **{ **user.model_dump(), - "is_active": Users.is_active(user), + 'is_active': Users.is_active(user), } ) for user in Users.get_users_by_user_ids(user_ids, db=db) ] - channel_member = Channels.get_member_by_channel_and_user_id( - channel.id, user.id, db=db - ) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -440,38 +390,30 @@ async def get_channel_by_id( return ChannelFullResponse( **{ **channel.model_dump(), - "user_ids": user_ids, - "users": users, - "is_manager": Channels.is_user_channel_manager( - channel.id, user.id, db=db - ), - "write_access": True, - "user_count": len(user_ids), - "last_read_at": channel_member.last_read_at if channel_member else None, - "unread_count": unread_count, + 'user_ids': user_ids, + 'users': users, + 'is_manager': Channels.is_user_channel_manager(channel.id, user.id, db=db), + 'write_access': True, + 'user_count': len(user_ids), + 'last_read_at': channel_member.last_read_at if channel_member else None, + 'unread_count': unread_count, } ) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) write_access = channel_has_access( user.id, channel, - permission="write", + permission='write', strict=False, db=db, ) - user_count = len(get_channel_users_with_access(channel, "read", db=db)) + user_count = len(get_channel_users_with_access(channel, 'read', db=db)) - channel_member = Channels.get_member_by_channel_and_user_id( - channel.id, user.id, db=db - ) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -479,15 +421,13 @@ async def get_channel_by_id( return ChannelFullResponse( **{ **channel.model_dump(), - "user_ids": user_ids, - "users": users, - "is_manager": Channels.is_user_channel_manager( - channel.id, user.id, db=db - ), - "write_access": write_access or user.role == "admin", - "user_count": user_count, - "last_read_at": channel_member.last_read_at if channel_member else None, - "unread_count": unread_count, + 'user_ids': user_ids, + 'users': users, + 'is_manager': Channels.is_user_channel_manager(channel.id, user.id, db=db), + 'write_access': write_access or user.role == 'admin', + 'user_count': user_count, + 'last_read_at': channel_member.last_read_at if channel_member else None, + 'unread_count': unread_count, } ) @@ -500,7 +440,7 @@ async def get_channel_by_id( PAGE_ITEM_COUNT = 30 -@router.get("/{id}/members", response_model=UserListResponse) +@router.get('/{id}/members', response_model=UserListResponse) async def get_channel_members_by_id( request: Request, id: str, @@ -515,68 +455,53 @@ async def get_channel_members_by_id( channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) limit = PAGE_ITEM_COUNT page = max(1, page) skip = (page - 1) * limit - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) - if channel.type == "dm": - user_ids = [ - member.user_id - for member in Channels.get_members_by_channel_id(channel.id, db=db) - ] + if channel.type == 'dm': + user_ids = [member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)] users = Users.get_users_by_user_ids(user_ids, db=db) total = len(users) return { - "users": [ - UserModelResponse(**user.model_dump(), is_active=Users.is_active(user)) - for user in users - ], - "total": total, + 'users': [UserModelResponse(**user.model_dump(), is_active=Users.is_active(user)) for user in users], + 'total': total, } else: filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction - if channel.type == "group": - filter["channel_id"] = channel.id + if channel.type == 'group': + filter['channel_id'] = channel.id else: - filter["roles"] = ["!pending"] - permitted_ids = get_channel_permitted_group_and_user_ids( - channel, permission="read" - ) + filter['roles'] = ['!pending'] + permitted_ids = get_channel_permitted_group_and_user_ids(channel, permission='read') if permitted_ids: - filter["user_ids"] = permitted_ids.get("user_ids") - filter["group_ids"] = permitted_ids.get("group_ids") + filter['user_ids'] = permitted_ids.get('user_ids') + filter['group_ids'] = permitted_ids.get('group_ids') result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) - users = result["users"] - total = result["total"] + users = result['users'] + total = result['total'] return { - "users": [ - UserModelResponse(**user.model_dump(), is_active=Users.is_active(user)) - for user in users - ], - "total": total, + 'users': [UserModelResponse(**user.model_dump(), is_active=Users.is_active(user)) for user in users], + 'total': total, } @@ -589,7 +514,7 @@ class UpdateActiveMemberForm(BaseModel): is_active: bool -@router.post("/{id}/members/active", response_model=bool) +@router.post('/{id}/members/active', response_model=bool) async def update_is_active_member_by_id_and_user_id( request: Request, id: str, @@ -600,18 +525,12 @@ async def update_is_active_member_by_id_and_user_id( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - Channels.update_member_active_status( - channel.id, user.id, form_data.is_active, db=db - ) + Channels.update_member_active_status(channel.id, user.id, form_data.is_active, db=db) return True @@ -625,7 +544,7 @@ class UpdateMembersForm(BaseModel): group_ids: list[str] = [] -@router.post("/{id}/update/members/add") +@router.post('/{id}/update/members/add') async def add_members_by_id( request: Request, id: str, @@ -636,14 +555,10 @@ async def add_members_by_id( check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.user_id != user.id and user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if channel.user_id != user.id and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: memberships = Channels.add_members_to_channel( @@ -653,9 +568,7 @@ async def add_members_by_id( return memberships except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ################################################# @@ -667,7 +580,7 @@ class RemoveMembersForm(BaseModel): user_ids: list[str] = [] -@router.post("/{id}/update/members/remove") +@router.post('/{id}/update/members/remove') async def remove_members_by_id( request: Request, id: str, @@ -679,26 +592,18 @@ async def remove_members_by_id( channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.user_id != user.id and user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if channel.user_id != user.id and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: - deleted = Channels.remove_members_from_channel( - channel.id, form_data.user_ids, db=db - ) + deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids, db=db) return deleted except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -706,7 +611,7 @@ async def remove_members_by_id( ############################ -@router.post("/{id}/update", response_model=Optional[ChannelModel]) +@router.post('/{id}/update', response_model=Optional[ChannelModel]) async def update_channel_by_id( request: Request, id: str, @@ -718,23 +623,17 @@ async def update_channel_by_id( channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.user_id != user.id and user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if channel.user_id != user.id and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: channel = Channels.update_channel_by_id(id, form_data, db=db) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -742,7 +641,7 @@ async def update_channel_by_id( ############################ -@router.delete("/{id}/delete", response_model=bool) +@router.delete('/{id}/delete', response_model=bool) async def delete_channel_by_id( request: Request, id: str, @@ -753,23 +652,17 @@ async def delete_channel_by_id( channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.user_id != user.id and user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if channel.user_id != user.id and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: Channels.delete_channel_by_id(id, db=db) return True except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -780,7 +673,7 @@ async def delete_channel_by_id( class MessageUserResponse(MessageResponse): data: bool | None = None - @field_validator("data", mode="before") + @field_validator('data', mode='before') def convert_data_to_bool(cls, v): # No data or not a dict → False if not isinstance(v, dict): @@ -790,7 +683,7 @@ class MessageUserResponse(MessageResponse): return any(bool(val) for val in v.values()) -@router.get("/{id}/messages", response_model=list[MessageUserResponse]) +@router.get('/{id}/messages', response_model=list[MessageUserResponse]) async def get_channel_messages( request: Request, id: str, @@ -802,26 +695,16 @@ async def get_channel_messages( check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) - channel_member = Channels.join_channel( - id, user.id, db=db - ) # Ensure user is a member of the channel + channel_member = Channels.join_channel(id, user.id, db=db) # Ensure user is a member of the channel message_list = Messages.get_messages_by_channel_id(id, skip, limit, db=db) @@ -835,9 +718,7 @@ async def get_channel_messages( messages = [] for message in message_list: thread_replies = Messages.get_thread_replies_by_message_id(message.id, db=db) - latest_thread_reply_at = ( - thread_replies[0].created_at if thread_replies else None - ) + latest_thread_reply_at = thread_replies[0].created_at if thread_replies else None # Use message.user if present (for webhooks), otherwise look up by user_id user_info = message.user @@ -848,12 +729,10 @@ async def get_channel_messages( MessageUserResponse( **{ **message.model_dump(), - "reply_count": len(thread_replies), - "latest_reply_at": latest_thread_reply_at, - "reactions": Messages.get_reactions_by_message_id( - message.id, db=db - ), - "user": user_info, + 'reply_count': len(thread_replies), + 'latest_reply_at': latest_thread_reply_at, + 'reactions': Messages.get_reactions_by_message_id(message.id, db=db), + 'user': user_info, } ) ) @@ -868,7 +747,7 @@ async def get_channel_messages( PAGE_ITEM_COUNT_PINNED = 20 -@router.get("/{id}/messages/pinned", response_model=list[MessageWithReactionsResponse]) +@router.get('/{id}/messages/pinned', response_model=list[MessageWithReactionsResponse]) async def get_pinned_channel_messages( request: Request, id: str, @@ -879,22 +758,14 @@ async def get_pinned_channel_messages( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) page = max(1, page) skip = (page - 1) * PAGE_ITEM_COUNT_PINNED @@ -912,12 +783,12 @@ async def get_pinned_channel_messages( messages = [] for message in message_list: # Check for webhook identity in meta - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None if webhook_info: user_info = UserNameResponse( - id=webhook_info.get("id"), - name=webhook_info.get("name"), - role="webhook", + id=webhook_info.get('id'), + name=webhook_info.get('name'), + role='webhook', ) elif message.user_id in users: user_info = UserNameResponse(**users[message.user_id].model_dump()) @@ -928,10 +799,8 @@ async def get_pinned_channel_messages( MessageWithReactionsResponse( **{ **message.model_dump(), - "reactions": Messages.get_reactions_by_message_id( - message.id, db=db - ), - "user": user_info, + 'reactions': Messages.get_reactions_by_message_id(message.id, db=db), + 'user': user_info, } ) ) @@ -944,29 +813,23 @@ async def get_pinned_channel_messages( ############################ -async def send_notification( - name, webui_url, channel, message, active_user_ids, db=None -): - users = get_channel_users_with_access(channel, "read", db=db) +async def send_notification(name, webui_url, channel, message, active_user_ids, db=None): + users = get_channel_users_with_access(channel, 'read', db=db) for user in users: - if (user.id not in active_user_ids) and Channels.is_user_channel_member( - channel.id, user.id, db=db - ): + if (user.id not in active_user_ids) and Channels.is_user_channel_member(channel.id, user.id, db=db): if user.settings: - webhook_url = user.settings.ui.get("notifications", {}).get( - "webhook_url", None - ) + webhook_url = user.settings.ui.get('notifications', {}).get('webhook_url', None) if webhook_url: await post_webhook( name, webhook_url, - f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}", + f'#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}', { - "action": "channel", - "message": message.content, - "title": channel.name, - "url": f"{webui_url}/channels/{channel.id}", + 'action': 'channel', + 'message': message.content, + 'title': channel.name, + 'url': f'{webui_url}/channels/{channel.id}', }, ) @@ -974,10 +837,7 @@ async def send_notification( async def model_response_handler(request, channel, message, user, db=None): - MODELS = { - model["id"]: model - for model in get_filtered_models(await get_all_models(request, user=user), user) - } + MODELS = {model['id']: model for model in get_filtered_models(await get_all_models(request, user=user), user)} mentions = extract_mentions(message.content) message_content = replace_mentions(message.content) @@ -988,21 +848,21 @@ async def model_response_handler(request, channel, message, user, db=None): if ( message.reply_to_message and message.reply_to_message.meta - and message.reply_to_message.meta.get("model_id", None) + and message.reply_to_message.meta.get('model_id', None) ): - model_id = message.reply_to_message.meta.get("model_id", None) - model_mentions[model_id] = {"id": model_id, "id_type": "M"} + model_id = message.reply_to_message.meta.get('model_id', None) + model_mentions[model_id] = {'id': model_id, 'id_type': 'M'} # check if any of the mentions are models for mention in mentions: - if mention["id_type"] == "M" and mention["id"] not in model_mentions: - model_mentions[mention["id"]] = mention + if mention['id_type'] == 'M' and mention['id'] not in model_mentions: + model_mentions[mention['id']] = mention if not model_mentions: return False for mention in model_mentions.values(): - model_id = mention["id"] + model_id = mention['id'] model = MODELS.get(model_id, None) if model: @@ -1019,14 +879,12 @@ async def model_response_handler(request, channel, message, user, db=None): channel.id, MessageForm( **{ - "parent_id": ( - message.parent_id if message.parent_id else message.id - ), - "content": f"", - "data": {}, - "meta": { - "model_id": model_id, - "model_name": model.get("name", model_id), + 'parent_id': (message.parent_id if message.parent_id else message.id), + 'content': f'', + 'data': {}, + 'meta': { + 'model_id': model_id, + 'model_name': model.get('name', model_id), }, } ), @@ -1041,63 +899,53 @@ async def model_response_handler(request, channel, message, user, db=None): for thread_message in thread_messages: message_user = None if thread_message.user_id not in message_users: - message_user = Users.get_user_by_id( - thread_message.user_id, db=db - ) + message_user = Users.get_user_by_id(thread_message.user_id, db=db) message_users[thread_message.user_id] = message_user else: message_user = message_users[thread_message.user_id] - if thread_message.meta and thread_message.meta.get( - "model_id", None - ): + if thread_message.meta and thread_message.meta.get('model_id', None): # If the message was sent by a model, use the model name - message_model_id = thread_message.meta.get("model_id", None) + message_model_id = thread_message.meta.get('model_id', None) message_model = MODELS.get(message_model_id, None) - username = ( - message_model.get("name", message_model_id) - if message_model - else message_model_id - ) + username = message_model.get('name', message_model_id) if message_model else message_model_id else: - username = message_user.name if message_user else "Unknown" + username = message_user.name if message_user else 'Unknown' - thread_history.append( - f"{username}: {replace_mentions(thread_message.content)}" - ) + thread_history.append(f'{username}: {replace_mentions(thread_message.content)}') - thread_message_files = (thread_message.data or {}).get("files", []) + thread_message_files = (thread_message.data or {}).get('files', []) for file in thread_message_files: - if file.get("type", "") == "image": - images.append(file.get("url", "")) - elif file.get("content_type", "").startswith("image/"): - image = get_image_base64_from_file_id(file.get("id", "")) + if file.get('type', '') == 'image': + images.append(file.get('url', '')) + elif file.get('content_type', '').startswith('image/'): + image = get_image_base64_from_file_id(file.get('id', '')) if image: images.append(image) - thread_history_string = "\n\n".join(thread_history) + thread_history_string = '\n\n'.join(thread_history) system_message = { - "role": "system", - "content": f"You are {model.get('name', model_id)}, participating in a threaded conversation. Be concise and conversational." + 'role': 'system', + 'content': f'You are {model.get("name", model_id)}, participating in a threaded conversation. Be concise and conversational.' + ( f"Here's the thread history:\n\n\n{thread_history_string}\n\n\nContinue the conversation naturally as {model.get('name', model_id)}, addressing the most recent message while being aware of the full context." if thread_history - else "" + else '' ), } - content = f"{user.name if user else 'User'}: {message_content}" + content = f'{user.name if user else "User"}: {message_content}' if images: content = [ { - "type": "text", - "text": content, + 'type': 'text', + 'text': content, }, *[ { - "type": "image_url", - "image_url": { - "url": image, + 'type': 'image_url', + 'image_url': { + 'url': image, }, } for image in images @@ -1105,12 +953,12 @@ async def model_response_handler(request, channel, message, user, db=None): ] form_data = { - "model": model_id, - "messages": [ + 'model': model_id, + 'messages': [ system_message, - {"role": "user", "content": content}, + {'role': 'user', 'content': content}, ], - "stream": False, + 'stream': False, } res = await generate_chat_completion( @@ -1120,32 +968,32 @@ async def model_response_handler(request, channel, message, user, db=None): ) if res: - if res.get("choices", []) and len(res["choices"]) > 0: + if res.get('choices', []) and len(res['choices']) > 0: await update_message_by_id( request, channel.id, response_message.id, MessageForm( **{ - "content": res["choices"][0]["message"]["content"], - "meta": { - "done": True, + 'content': res['choices'][0]['message']['content'], + 'meta': { + 'done': True, }, } ), user, db, ) - elif res.get("error", None): + elif res.get('error', None): await update_message_by_id( request, channel.id, response_message.id, MessageForm( **{ - "content": f"Error: {res['error']}", - "meta": { - "done": True, + 'content': f'Error: {res["error"]}', + 'meta': { + 'done': True, }, } ), @@ -1159,59 +1007,49 @@ async def model_response_handler(request, channel, message, user, db=None): return True -async def new_message_handler( - request: Request, id: str, form_data: MessageForm, user, db -): +async def new_message_handler(request: Request, id: str, form_data: MessageForm, user, db): channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( + if user.role != 'admin' and not channel_has_access( user.id, channel, - permission="write", + permission='write', strict=False, db=db, ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: message = Messages.insert_new_message(form_data, channel.id, user.id, db=db) if message: - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: members = Channels.get_members_by_channel_id(channel.id, db=db) for member in members: if not member.is_active: - Channels.update_member_active_status( - channel.id, member.user_id, True, db=db - ) + Channels.update_member_active_status(channel.id, member.user_id, True, db=db) message = Messages.get_message_by_id(message.id, db=db) event_data = { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message", - "data": {"temp_id": form_data.temp_id, **message.model_dump()}, + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message', + 'data': {'temp_id': form_data.temp_id, **message.model_dump()}, }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), } await sio.emit( - "events:channel", + 'events:channel', event_data, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) if message.parent_id: @@ -1220,30 +1058,28 @@ async def new_message_handler( if parent_message: await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": parent_message.id, - "data": { - "type": "message:reply", - "data": parent_message.model_dump(), + 'channel_id': channel.id, + 'message_id': parent_message.id, + 'data': { + 'type': 'message:reply', + 'data': parent_message.model_dump(), }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) return message, channel else: - raise Exception("Error creating message") + raise Exception('Error creating message') except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) -@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) +@router.post('/{id}/messages/post', response_model=Optional[MessageModel]) async def post_new_message( request: Request, id: str, @@ -1257,15 +1093,13 @@ async def post_new_message( try: message, channel = await new_message_handler(request, id, form_data, user, db) try: - if files := message.data.get("files", []): + if files := message.data.get('files', []): for file in files: - Channels.set_file_message_id_in_channel_by_id( - channel.id, file.get("id", ""), message.id, db=db - ) + Channels.set_file_message_id_in_channel_by_id(channel.id, file.get('id', ''), message.id, db=db) except Exception as e: log.debug(e) - active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") + active_user_ids = get_user_ids_from_room(f'channel:{channel.id}') # NOTE: We intentionally do NOT pass db to background_handler. # Background tasks should manage their own short-lived sessions to avoid @@ -1288,9 +1122,7 @@ async def post_new_message( raise e except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1298,7 +1130,7 @@ async def post_new_message( ############################ -@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageResponse]) +@router.get('/{id}/messages/{message_id}', response_model=Optional[MessageResponse]) async def get_channel_message( request: Request, id: str, @@ -1309,40 +1141,26 @@ async def get_channel_message( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) return MessageResponse( **{ **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id, db=db).model_dump() - ), + 'user': UserNameResponse(**Users.get_user_by_id(message.user_id, db=db).model_dump()), } ) @@ -1352,7 +1170,7 @@ async def get_channel_message( ############################ -@router.get("/{id}/messages/{message_id}/data", response_model=Optional[dict]) +@router.get('/{id}/messages/{message_id}/data', response_model=Optional[dict]) async def get_channel_message_data( request: Request, id: str, @@ -1363,33 +1181,21 @@ async def get_channel_message_data( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) return message.data @@ -1403,9 +1209,7 @@ class PinMessageForm(BaseModel): is_pinned: bool -@router.post( - "/{id}/messages/{message_id}/pin", response_model=Optional[MessageUserResponse] -) +@router.post('/{id}/messages/{message_id}/pin', response_model=Optional[MessageUserResponse]) async def pin_channel_message( request: Request, id: str, @@ -1417,33 +1221,21 @@ async def pin_channel_message( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) try: Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id, db=db) @@ -1451,16 +1243,12 @@ async def pin_channel_message( return MessageUserResponse( **{ **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id, db=db).model_dump() - ), + 'user': UserNameResponse(**Users.get_user_by_id(message.user_id, db=db).model_dump()), } ) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1468,9 +1256,7 @@ async def pin_channel_message( ############################ -@router.get( - "/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse] -) +@router.get('/{id}/messages/{message_id}/thread', response_model=list[MessageUserResponse]) async def get_channel_thread_messages( request: Request, id: str, @@ -1483,26 +1269,16 @@ async def get_channel_thread_messages( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( - user.id, channel, permission="read", db=db - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + if user.role != 'admin' and not channel_has_access(user.id, channel, permission='read', db=db): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) - message_list = Messages.get_messages_by_parent_id( - id, message_id, skip, limit, db=db - ) + message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit, db=db) if not message_list: return [] @@ -1522,12 +1298,10 @@ async def get_channel_thread_messages( MessageUserResponse( **{ **message.model_dump(), - "reply_count": 0, - "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id( - message.id, db=db - ), - "user": user_info, + 'reply_count': 0, + 'latest_reply_at': None, + 'reactions': Messages.get_reactions_by_message_id(message.id, db=db), + 'user': user_info, } ) ) @@ -1540,9 +1314,7 @@ async def get_channel_thread_messages( ############################ -@router.post( - "/{id}/messages/{message_id}/update", response_model=Optional[MessageModel] -) +@router.post('/{id}/messages/{message_id}/update', response_model=Optional[MessageModel]) async def update_message_by_id( request: Request, id: str, @@ -1554,37 +1326,25 @@ async def update_message_by_id( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: if ( - user.role != "admin" + user.role != 'admin' and message.user_id != user.id - and not channel_has_access( - user.id, channel, permission="write", strict=False, db=db - ) + and not channel_has_access(user.id, channel, permission='write', strict=False, db=db) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: message = Messages.update_message_by_id(message_id, form_data, db=db) @@ -1592,26 +1352,24 @@ async def update_message_by_id( if message: await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message:update", - "data": message.model_dump(), + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message:update', + 'data': message.model_dump(), }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) return MessageModel(**message.model_dump()) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1623,7 +1381,7 @@ class ReactionForm(BaseModel): name: str -@router.post("/{id}/messages/{message_id}/reactions/add", response_model=bool) +@router.post('/{id}/messages/{message_id}/reactions/add', response_model=bool) async def add_reaction_to_message( request: Request, id: str, @@ -1635,66 +1393,54 @@ async def add_reaction_to_message( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( + if user.role != 'admin' and not channel_has_access( user.id, channel, - permission="write", + permission='write', strict=False, db=db, ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) try: Messages.add_reaction_to_message(message_id, user.id, form_data.name, db=db) message = Messages.get_message_by_id(message_id, db=db) await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message:reaction:add", - "data": { + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message:reaction:add', + 'data': { **message.model_dump(), - "name": form_data.name, + 'name': form_data.name, }, }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) return True except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1702,7 +1448,7 @@ async def add_reaction_to_message( ############################ -@router.post("/{id}/messages/{message_id}/reactions/remove", response_model=bool) +@router.post('/{id}/messages/{message_id}/reactions/remove', response_model=bool) async def remove_reaction_by_id_and_user_id_and_name( request: Request, id: str, @@ -1714,69 +1460,55 @@ async def remove_reaction_by_id_and_user_id_and_name( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: - if user.role != "admin" and not channel_has_access( + if user.role != 'admin' and not channel_has_access( user.id, channel, - permission="write", + permission='write', strict=False, db=db, ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) try: - Messages.remove_reaction_by_id_and_user_id_and_name( - message_id, user.id, form_data.name, db=db - ) + Messages.remove_reaction_by_id_and_user_id_and_name(message_id, user.id, form_data.name, db=db) message = Messages.get_message_by_id(message_id, db=db) await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message:reaction:remove", - "data": { + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message:reaction:remove', + 'data': { **message.model_dump(), - "name": form_data.name, + 'name': form_data.name, }, }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) return True except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1784,7 +1516,7 @@ async def remove_reaction_by_id_and_user_id_and_name( ############################ -@router.delete("/{id}/messages/{message_id}/delete", response_model=bool) +@router.delete('/{id}/messages/{message_id}/delete', response_model=bool) async def delete_message_by_id( request: Request, id: str, @@ -1795,60 +1527,50 @@ async def delete_message_by_id( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) message = Messages.get_message_by_id(message_id, db=db) if not message: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) if message.channel_id != id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) - if channel.type in ["group", "dm"]: + if channel.type in ['group', 'dm']: if not Channels.is_user_channel_member(channel.id, user.id, db=db): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) else: if ( - user.role != "admin" + user.role != 'admin' and message.user_id != user.id and not channel_has_access( user.id, channel, - permission="write", + permission='write', strict=False, db=db, ) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: Messages.delete_message_by_id(message_id, db=db) await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message:delete", - "data": { + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message:delete', + 'data': { **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), }, }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) if message.parent_id: @@ -1857,26 +1579,24 @@ async def delete_message_by_id( if parent_message: await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": channel.id, - "message_id": parent_message.id, - "data": { - "type": "message:reply", - "data": parent_message.model_dump(), + 'channel_id': channel.id, + 'message_id': parent_message.id, + 'data': { + 'type': 'message:reply', + 'data': parent_message.model_dump(), }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), + 'user': UserNameResponse(**user.model_dump()).model_dump(), + 'channel': channel.model_dump(), }, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) return True except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1884,41 +1604,41 @@ async def delete_message_by_id( ############################ -@router.get("/webhooks/{webhook_id}/profile/image") +@router.get('/webhooks/{webhook_id}/profile/image') def get_webhook_profile_image(webhook_id: str, user=Depends(get_verified_user)): """Get webhook profile image by webhook ID.""" webhook = Channels.get_webhook_by_id(webhook_id) if not webhook: # Return default favicon if webhook not found - return FileResponse(f"{STATIC_DIR}/favicon.png") + return FileResponse(f'{STATIC_DIR}/favicon.png') if webhook.profile_image_url: # Check if it's url or base64 - if webhook.profile_image_url.startswith("http"): + if webhook.profile_image_url.startswith('http'): return Response( status_code=status.HTTP_302_FOUND, - headers={"Location": webhook.profile_image_url}, + headers={'Location': webhook.profile_image_url}, ) - elif webhook.profile_image_url.startswith("data:image"): + elif webhook.profile_image_url.startswith('data:image'): try: - header, base64_data = webhook.profile_image_url.split(",", 1) + header, base64_data = webhook.profile_image_url.split(',', 1) image_data = base64.b64decode(base64_data) image_buffer = io.BytesIO(image_data) - media_type = header.split(";")[0].lstrip("data:") + media_type = header.split(';')[0].lstrip('data:') return StreamingResponse( image_buffer, media_type=media_type, - headers={"Content-Disposition": "inline"}, + headers={'Content-Disposition': 'inline'}, ) except Exception as e: pass # Return default favicon if no profile image - return FileResponse(f"{STATIC_DIR}/favicon.png") + return FileResponse(f'{STATIC_DIR}/favicon.png') -@router.get("/{id}/webhooks", response_model=list[ChannelWebhookModel]) +@router.get('/{id}/webhooks', response_model=list[ChannelWebhookModel]) async def get_channel_webhooks( request: Request, id: str, @@ -1928,23 +1648,16 @@ async def get_channel_webhooks( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) # Only channel managers can view webhooks - if ( - not Channels.is_user_channel_manager(channel.id, user.id, db=db) - and user.role != "admin" - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED - ) + if not Channels.is_user_channel_manager(channel.id, user.id, db=db) and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED) return Channels.get_webhooks_by_channel_id(id, db=db) -@router.post("/{id}/webhooks/create", response_model=ChannelWebhookModel) +@router.post('/{id}/webhooks/create', response_model=ChannelWebhookModel) async def create_channel_webhook( request: Request, id: str, @@ -1955,29 +1668,20 @@ async def create_channel_webhook( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) # Only channel managers can create webhooks - if ( - not Channels.is_user_channel_manager(channel.id, user.id, db=db) - and user.role != "admin" - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED - ) + if not Channels.is_user_channel_manager(channel.id, user.id, db=db) and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED) webhook = Channels.insert_webhook(id, user.id, form_data, db=db) if not webhook: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) return webhook -@router.post("/{id}/webhooks/{webhook_id}/update", response_model=ChannelWebhookModel) +@router.post('/{id}/webhooks/{webhook_id}/update', response_model=ChannelWebhookModel) async def update_channel_webhook( request: Request, id: str, @@ -1989,35 +1693,24 @@ async def update_channel_webhook( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) # Only channel managers can update webhooks - if ( - not Channels.is_user_channel_manager(channel.id, user.id, db=db) - and user.role != "admin" - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED - ) + if not Channels.is_user_channel_manager(channel.id, user.id, db=db) and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED) webhook = Channels.get_webhook_by_id(webhook_id, db=db) if not webhook or webhook.channel_id != id: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) updated = Channels.update_webhook_by_id(webhook_id, form_data, db=db) if not updated: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) return updated -@router.delete("/{id}/webhooks/{webhook_id}/delete", response_model=bool) +@router.delete('/{id}/webhooks/{webhook_id}/delete', response_model=bool) async def delete_channel_webhook( request: Request, id: str, @@ -2028,24 +1721,15 @@ async def delete_channel_webhook( check_channels_access(request) channel = Channels.get_channel_by_id(id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) # Only channel managers can delete webhooks - if ( - not Channels.is_user_channel_manager(channel.id, user.id, db=db) - and user.role != "admin" - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED - ) + if not Channels.is_user_channel_manager(channel.id, user.id, db=db) and user.role != 'admin': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED) webhook = Channels.get_webhook_by_id(webhook_id, db=db) if not webhook or webhook.channel_id != id: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) return Channels.delete_webhook_by_id(webhook_id, db=db) @@ -2059,7 +1743,7 @@ class WebhookMessageForm(BaseModel): content: str -@router.post("/webhooks/{webhook_id}/{token}") +@router.post('/webhooks/{webhook_id}/{token}') async def post_webhook_message( request: Request, webhook_id: str, @@ -2075,18 +1759,16 @@ async def post_webhook_message( if not webhook: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid webhook URL", + detail='Invalid webhook URL', ) channel = Channels.get_channel_by_id(webhook.channel_id, db=db) if not channel: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) # Create message with webhook identity stored in meta message = Messages.insert_new_message( - MessageForm(content=form_data.content, meta={"webhook": {"id": webhook.id}}), + MessageForm(content=form_data.content, meta={'webhook': {'id': webhook.id}}), webhook.channel_id, webhook.user_id, # Required for DB but webhook info in meta takes precedence db=db, @@ -2095,7 +1777,7 @@ async def post_webhook_message( if not message: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Failed to create message", + detail='Failed to create message', ) # Update last_used_at @@ -2105,31 +1787,31 @@ async def post_webhook_message( message = Messages.get_message_by_id(message.id, db=db) event_data = { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message", - "data": { + 'channel_id': channel.id, + 'message_id': message.id, + 'data': { + 'type': 'message', + 'data': { **message.model_dump(), - "user": { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'user': { + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', }, }, }, - "user": { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'user': { + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', }, - "channel": channel.model_dump(), + 'channel': channel.model_dump(), } await sio.emit( - "events:channel", + 'events:channel', event_data, - to=f"channel:{channel.id}", + to=f'channel:{channel.id}', ) - return {"success": True, "message_id": message.id} + return {'success': True, 'message_id': message.id} diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 837170c89d..79d0525698 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -45,8 +45,8 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[ChatTitleIdResponse]) -@router.get("/list", response_model=list[ChatTitleIdResponse]) +@router.get('/', response_model=list[ChatTitleIdResponse]) +@router.get('/list', response_model=list[ChatTitleIdResponse]) def get_session_user_chat_list( user=Depends(get_verified_user), page: Optional[int] = None, @@ -76,9 +76,7 @@ def get_session_user_chat_list( ) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -87,7 +85,7 @@ def get_session_user_chat_list( ############################ -@router.get("/stats/usage", response_model=ChatUsageStatsListResponse) +@router.get('/stats/usage', response_model=ChatUsageStatsListResponse) def get_session_user_chat_usage_stats( items_per_page: Optional[int] = 50, page: Optional[int] = 1, @@ -105,8 +103,8 @@ def get_session_user_chat_usage_stats( chat_stats = [] for chat in chats: - messages_map = chat.chat.get("history", {}).get("messages", {}) - message_id = chat.chat.get("history", {}).get("currentId") + messages_map = chat.chat.get('history', {}).get('messages', {}) + message_id = chat.chat.get('history', {}).get('currentId') if messages_map and message_id: try: @@ -116,30 +114,24 @@ def get_session_user_chat_usage_stats( history_assistant_messages = [] for message in messages_map.values(): - if message.get("role", "") == "user": + if message.get('role', '') == 'user': history_user_messages.append(message) - elif message.get("role", "") == "assistant": + elif message.get('role', '') == 'assistant': history_assistant_messages.append(message) - model = message.get("model", None) + model = message.get('model', None) if model: if model not in history_models: history_models[model] = 0 history_models[model] += 1 average_user_message_content_length = ( - sum( - len(message.get("content", "")) - for message in history_user_messages - ) + sum(len(message.get('content', '')) for message in history_user_messages) / len(history_user_messages) if len(history_user_messages) > 0 else 0 ) average_assistant_message_content_length = ( - sum( - len(message.get("content", "")) - for message in history_assistant_messages - ) + sum(len(message.get('content', '')) for message in history_assistant_messages) / len(history_assistant_messages) if len(history_assistant_messages) > 0 else 0 @@ -147,53 +139,45 @@ def get_session_user_chat_usage_stats( response_times = [] for message in history_assistant_messages: - user_message_id = message.get("parentId", None) + user_message_id = message.get('parentId', None) if user_message_id and user_message_id in messages_map: user_message = messages_map[user_message_id] - response_time = message.get( - "timestamp", 0 - ) - user_message.get("timestamp", 0) + response_time = message.get('timestamp', 0) - user_message.get('timestamp', 0) response_times.append(response_time) - average_response_time = ( - sum(response_times) / len(response_times) - if len(response_times) > 0 - else 0 - ) + average_response_time = sum(response_times) / len(response_times) if len(response_times) > 0 else 0 message_list = get_message_list(messages_map, message_id) message_count = len(message_list) models = {} for message in reversed(message_list): - if message.get("role") == "assistant": - model = message.get("model", None) + if message.get('role') == 'assistant': + model = message.get('model', None) if model: if model not in models: models[model] = 0 models[model] += 1 - annotation = message.get("annotation", {}) + annotation = message.get('annotation', {}) chat_stats.append( { - "id": chat.id, - "models": models, - "message_count": message_count, - "history_models": history_models, - "history_message_count": history_message_count, - "history_user_message_count": len(history_user_messages), - "history_assistant_message_count": len( - history_assistant_messages - ), - "average_response_time": average_response_time, - "average_user_message_content_length": average_user_message_content_length, - "average_assistant_message_content_length": average_assistant_message_content_length, - "tags": chat.meta.get("tags", []), - "last_message_at": message_list[-1].get("timestamp", None), - "updated_at": chat.updated_at, - "created_at": chat.created_at, + 'id': chat.id, + 'models': models, + 'message_count': message_count, + 'history_models': history_models, + 'history_message_count': history_message_count, + 'history_user_message_count': len(history_user_messages), + 'history_assistant_message_count': len(history_assistant_messages), + 'average_response_time': average_response_time, + 'average_user_message_content_length': average_user_message_content_length, + 'average_assistant_message_content_length': average_assistant_message_content_length, + 'tags': chat.meta.get('tags', []), + 'last_message_at': message_list[-1].get('timestamp', None), + 'updated_at': chat.updated_at, + 'created_at': chat.created_at, } ) except Exception as e: @@ -203,9 +187,7 @@ def get_session_user_chat_usage_stats( except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -217,7 +199,7 @@ CHAT_EXPORT_PAGE_ITEM_COUNT = 10 class ChatStatsExportList(BaseModel): - type: str = "chats" + type: str = 'chats' items: list[ChatStatsExport] total: int page: int @@ -227,19 +209,15 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: try: def get_message_content_length(message): - content = message.get("content", "") + content = message.get('content', '') if isinstance(content, str): return len(content) elif isinstance(content, list): - return sum( - len(item.get("text", "")) - for item in content - if item.get("type") == "text" - ) + return sum(len(item.get('text', '')) for item in content if item.get('type') == 'text') return 0 - messages_map = chat.chat.get("history", {}).get("messages", {}) - message_id = chat.chat.get("history", {}).get("currentId") + messages_map = chat.chat.get('history', {}).get('messages', {}) + message_id = chat.chat.get('history', {}).get('currentId') history_models = {} history_message_count = len(messages_map) @@ -252,14 +230,14 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: content_length = get_message_content_length(message) # Extract rating safely - rating = message.get("annotation", {}).get("rating") - tags = message.get("annotation", {}).get("tags") + rating = message.get('annotation', {}).get('rating') + tags = message.get('annotation', {}).get('tags') message_stat = MessageStats( - id=message.get("id"), - role=message.get("role"), - model=message.get("model"), - timestamp=message.get("timestamp"), + id=message.get('id'), + role=message.get('role'), + model=message.get('model'), + timestamp=message.get('timestamp'), content_length=content_length, token_count=None, # Populate if available, e.g. message.get("info", {}).get("token_count") rating=rating, @@ -269,31 +247,29 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: export_messages[key] = message_stat # --- Aggregation Logic (copied/adapted from usage stats) --- - role = message.get("role", "") - if role == "user": + role = message.get('role', '') + if role == 'user': history_user_messages.append(message) - elif role == "assistant": + elif role == 'assistant': history_assistant_messages.append(message) - model = message.get("model") + model = message.get('model') if model: if model not in history_models: history_models[model] = 0 history_models[model] += 1 except Exception as e: - log.debug(f"Error processing message {key}: {e}") + log.debug(f'Error processing message {key}: {e}') continue # Calculate Averages average_user_message_content_length = ( - sum(get_message_content_length(m) for m in history_user_messages) - / len(history_user_messages) + sum(get_message_content_length(m) for m in history_user_messages) / len(history_user_messages) if history_user_messages else 0 ) average_assistant_message_content_length = ( - sum(get_message_content_length(m) for m in history_assistant_messages) - / len(history_assistant_messages) + sum(get_message_content_length(m) for m in history_assistant_messages) / len(history_assistant_messages) if history_assistant_messages else 0 ) @@ -301,26 +277,24 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: # Response Times response_times = [] for message in history_assistant_messages: - user_message_id = message.get("parentId", None) + user_message_id = message.get('parentId', None) if user_message_id and user_message_id in messages_map: user_message = messages_map[user_message_id] # Ensure timestamps exist - t1 = message.get("timestamp") - t0 = user_message.get("timestamp") + t1 = message.get('timestamp') + t0 = user_message.get('timestamp') if t1 and t0: response_times.append(t1 - t0) - average_response_time = ( - sum(response_times) / len(response_times) if response_times else 0 - ) + average_response_time = sum(response_times) / len(response_times) if response_times else 0 # Current Message List Logic (Main path) message_list = get_message_list(messages_map, message_id) message_count = len(message_list) models = {} for message in reversed(message_list): - if message.get("role") == "assistant": - model = message.get("model") + if message.get('role') == 'assistant': + model = message.get('model') if model: if model not in models: models[model] = 0 @@ -340,21 +314,19 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: ) # Construct Chat Body - chat_body = ChatBody( - history=ChatHistoryStats(messages=export_messages, currentId=message_id) - ) + chat_body = ChatBody(history=ChatHistoryStats(messages=export_messages, currentId=message_id)) return ChatStatsExport( id=chat.id, user_id=chat.user_id, created_at=chat.created_at, updated_at=chat.updated_at, - tags=chat.meta.get("tags", []), + tags=chat.meta.get('tags', []), stats=stats, chat=chat_body, ) except Exception as e: - log.exception(f"Error exporting stats for chat {chat.id}: {e}") + log.exception(f'Error exporting stats for chat {chat.id}: {e}') return None @@ -408,14 +380,14 @@ def generate_chat_stats_jsonl_generator(user_id, filter): try: chat_stat = _process_chat_for_export(chat) if chat_stat: - yield chat_stat.model_dump_json() + "\n" + yield chat_stat.model_dump_json() + '\n' except Exception as e: - log.exception(f"Error processing chat {chat.id}: {e}") + log.exception(f'Error processing chat {chat.id}: {e}') skip += limit -@router.get("/stats/export", response_model=ChatStatsExportList) +@router.get('/stats/export', response_model=ChatStatsExportList) async def export_chat_stats( request: Request, updated_at: Optional[int] = None, @@ -424,9 +396,7 @@ async def export_chat_stats( user=Depends(get_verified_user), ): # Check if the user has permission to share/export chats - if (user.role != "admin") and ( - not request.app.state.config.ENABLE_COMMUNITY_SHARING - ): + if (user.role != 'admin') and (not request.app.state.config.ENABLE_COMMUNITY_SHARING): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -434,36 +404,28 @@ async def export_chat_stats( try: # Fetch chats with date filtering - filter = {"order_by": "updated_at", "direction": "asc"} + filter = {'order_by': 'updated_at', 'direction': 'asc'} if updated_at: - filter["updated_at"] = updated_at + filter['updated_at'] = updated_at if stream: return StreamingResponse( generate_chat_stats_jsonl_generator(user.id, filter), - media_type="application/x-ndjson", - headers={ - "Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl" - }, + media_type='application/x-ndjson', + headers={'Content-Disposition': f'attachment; filename=chat-stats-export-{user.id}.jsonl'}, ) else: limit = CHAT_EXPORT_PAGE_ITEM_COUNT skip = (page - 1) * limit - chat_stats_export_list, total = await asyncio.to_thread( - calculate_chat_stats, user.id, skip, limit, filter - ) + chat_stats_export_list, total = await asyncio.to_thread(calculate_chat_stats, user.id, skip, limit, filter) - return ChatStatsExportList( - items=chat_stats_export_list, total=total, page=page - ) + return ChatStatsExportList(items=chat_stats_export_list, total=total, page=page) except Exception as e: - log.debug(f"Error exporting chat stats: {e}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + log.debug(f'Error exporting chat stats: {e}') + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -471,7 +433,7 @@ async def export_chat_stats( ############################ -@router.get("/stats/export/{chat_id}", response_model=Optional[ChatStatsExport]) +@router.get('/stats/export/{chat_id}', response_model=Optional[ChatStatsExport]) async def export_single_chat_stats( request: Request, chat_id: str, @@ -483,9 +445,7 @@ async def export_single_chat_stats( Returns ChatStatsExport for the specified chat. """ # Check if the user has permission to share/export chats - if (user.role != "admin") and ( - not request.app.state.config.ENABLE_COMMUNITY_SHARING - ): + if (user.role != 'admin') and (not request.app.state.config.ENABLE_COMMUNITY_SHARING): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -501,7 +461,7 @@ async def export_single_chat_stats( ) # Verify the chat belongs to the user (unless admin) - if chat.user_id != user.id and user.role != "admin": + if chat.user_id != user.id and user.role != 'admin': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -513,7 +473,7 @@ async def export_single_chat_stats( if not chat_stats: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Failed to process chat stats", + detail='Failed to process chat stats', ) return chat_stats @@ -521,22 +481,17 @@ async def export_single_chat_stats( except HTTPException: raise except Exception as e: - log.debug(f"Error exporting single chat stats: {e}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + log.debug(f'Error exporting single chat stats: {e}') + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) -@router.delete("/", response_model=bool) +@router.delete('/', response_model=bool) async def delete_all_user_chats( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - - if user.role == "user" and not has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS - ): + if user.role == 'user' and not has_permission(user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -551,7 +506,7 @@ async def delete_all_user_chats( ############################ -@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse]) +@router.get('/list/user/{user_id}', response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( user_id: str, page: Optional[int] = None, @@ -575,15 +530,13 @@ async def get_user_chat_list_by_user_id( filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction - return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db - ) + return Chats.get_chat_list_by_user_id(user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db) ############################ @@ -591,7 +544,7 @@ async def get_user_chat_list_by_user_id( ############################ -@router.post("/new", response_model=Optional[ChatResponse]) +@router.post('/new', response_model=Optional[ChatResponse]) async def create_new_chat( form_data: ChatForm, user=Depends(get_verified_user), @@ -602,9 +555,7 @@ async def create_new_chat( return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -612,7 +563,7 @@ async def create_new_chat( ############################ -@router.post("/import", response_model=list[ChatResponse]) +@router.post('/import', response_model=list[ChatResponse]) async def import_chats( form_data: ChatsImportForm, user=Depends(get_verified_user), @@ -623,9 +574,7 @@ async def import_chats( return chats except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -633,7 +582,7 @@ async def import_chats( ############################ -@router.get("/search", response_model=list[ChatTitleIdResponse]) +@router.get('/search', response_model=list[ChatTitleIdResponse]) def search_user_chats( text: str, page: Optional[int] = None, @@ -648,18 +597,16 @@ def search_user_chats( chat_list = [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_user_id_and_search_text( - user.id, text, skip=skip, limit=limit, db=db - ) + for chat in Chats.get_chats_by_user_id_and_search_text(user.id, text, skip=skip, limit=limit, db=db) ] # Delete tag if no chat is found - words = text.strip().split(" ") - if page == 1 and len(words) == 1 and words[0].startswith("tag:"): - tag_id = words[0].replace("tag:", "") + words = text.strip().split(' ') + if page == 1 and len(words) == 1 and words[0].startswith('tag:'): + tag_id = words[0].replace('tag:', '') if len(chat_list) == 0: if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db): - log.debug(f"deleting tag: {tag_id}") + log.debug(f'deleting tag: {tag_id}') Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) return chat_list @@ -670,26 +617,20 @@ def search_user_chats( ############################ -@router.get("/folder/{folder_id}", response_model=list[ChatResponse]) -async def get_chats_by_folder_id( - folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/folder/{folder_id}', response_model=list[ChatResponse]) +async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): folder_ids = [folder_id] - children_folders = Folders.get_children_folders_by_id_and_user_id( - folder_id, user.id, db=db - ) + children_folders = Folders.get_children_folders_by_id_and_user_id(folder_id, user.id, db=db) if children_folders: folder_ids.extend([folder.id for folder in children_folders]) return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_folder_ids_and_user_id( - folder_ids, user.id, db=db - ) + for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id, db=db) ] -@router.get("/folder/{folder_id}/list") +@router.get('/folder/{folder_id}/list') async def get_chat_list_by_folder_id( folder_id: str, page: Optional[int] = 1, @@ -701,17 +642,13 @@ async def get_chat_list_by_folder_id( skip = (page - 1) * limit return [ - {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} - for chat in Chats.get_chats_by_folder_id_and_user_id( - folder_id, user.id, skip=skip, limit=limit, db=db - ) + {'title': chat.title, 'id': chat.id, 'updated_at': chat.updated_at} + for chat in Chats.get_chats_by_folder_id_and_user_id(folder_id, user.id, skip=skip, limit=limit, db=db) ] except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -719,10 +656,8 @@ async def get_chat_list_by_folder_id( ############################ -@router.get("/pinned", response_model=list[ChatTitleIdResponse]) -async def get_user_pinned_chats( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/pinned', response_model=list[ChatTitleIdResponse]) +async def get_user_pinned_chats(user=Depends(get_verified_user), db: Session = Depends(get_session)): return Chats.get_pinned_chats_by_user_id(user.id, db=db) @@ -731,10 +666,8 @@ async def get_user_pinned_chats( ############################ -@router.get("/all", response_model=list[ChatResponse]) -async def get_user_chats( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/all', response_model=list[ChatResponse]) +async def get_user_chats(user=Depends(get_verified_user), db: Session = Depends(get_session)): result = Chats.get_chats_by_user_id(user.id, db=db) return [ChatResponse(**chat.model_dump()) for chat in result.items] @@ -744,14 +677,9 @@ async def get_user_chats( ############################ -@router.get("/all/archived", response_model=list[ChatResponse]) -async def get_user_archived_chats( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - return [ - ChatResponse(**chat.model_dump()) - for chat in Chats.get_archived_chats_by_user_id(user.id, db=db) - ] +@router.get('/all/archived', response_model=list[ChatResponse]) +async def get_user_archived_chats(user=Depends(get_verified_user), db: Session = Depends(get_session)): + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_archived_chats_by_user_id(user.id, db=db)] ############################ @@ -759,18 +687,14 @@ async def get_user_archived_chats( ############################ -@router.get("/all/tags", response_model=list[TagModel]) -async def get_all_user_tags( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/all/tags', response_model=list[TagModel]) +async def get_all_user_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)): try: tags = Tags.get_tags_by_user_id(user.id, db=db) return tags except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -778,10 +702,8 @@ async def get_all_user_tags( ############################ -@router.get("/all/db", response_model=list[ChatResponse]) -async def get_all_user_chats_in_db( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/all/db', response_model=list[ChatResponse]) +async def get_all_user_chats_in_db(user=Depends(get_admin_user), db: Session = Depends(get_session)): if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -795,7 +717,7 @@ async def get_all_user_chats_in_db( ############################ -@router.get("/archived", response_model=list[ChatTitleIdResponse]) +@router.get('/archived', response_model=list[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( page: Optional[int] = None, query: Optional[str] = None, @@ -812,11 +734,11 @@ async def get_archived_session_user_chat_list( filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction return Chats.get_archived_chat_list_by_user_id( user.id, @@ -832,10 +754,8 @@ async def get_archived_session_user_chat_list( ############################ -@router.post("/archive/all", response_model=bool) -async def archive_all_chats( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/archive/all', response_model=bool) +async def archive_all_chats(user=Depends(get_verified_user), db: Session = Depends(get_session)): return Chats.archive_all_chats_by_user_id(user.id, db=db) @@ -844,10 +764,8 @@ async def archive_all_chats( ############################ -@router.post("/unarchive/all", response_model=bool) -async def unarchive_all_chats( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/unarchive/all', response_model=bool) +async def unarchive_all_chats(user=Depends(get_verified_user), db: Session = Depends(get_session)): return Chats.unarchive_all_chats_by_user_id(user.id, db=db) @@ -856,7 +774,7 @@ async def unarchive_all_chats( ############################ -@router.get("/shared", response_model=list[SharedChatResponse]) +@router.get('/shared', response_model=list[SharedChatResponse]) async def get_shared_session_user_chat_list( page: Optional[int] = None, query: Optional[str] = None, @@ -873,11 +791,11 @@ async def get_shared_session_user_chat_list( filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction return Chats.get_shared_chat_list_by_user_id( user.id, @@ -893,27 +811,21 @@ async def get_shared_session_user_chat_list( ############################ -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id( - share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) +@router.get('/share/{share_id}', response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'pending': + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) - if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): + if user.role == 'user' or (user.role == 'admin' and not ENABLE_ADMIN_CHAT_ACCESS): chat = Chats.get_chat_by_share_id(share_id, db=db) - elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: + elif user.role == 'admin' and ENABLE_ADMIN_CHAT_ACCESS: chat = Chats.get_chat_by_id(share_id, db=db) if chat: return ChatResponse(**chat.model_dump()) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) ############################ @@ -930,15 +842,13 @@ class TagFilterForm(TagForm): limit: Optional[int] = 50 -@router.post("/tags", response_model=list[ChatTitleIdResponse]) +@router.post('/tags', response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( form_data: TagFilterForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - chats = Chats.get_chat_list_by_user_id_and_tag_name( - user.id, form_data.name, form_data.skip, form_data.limit, db=db - ) + chats = Chats.get_chat_list_by_user_id_and_tag_name(user.id, form_data.name, form_data.skip, form_data.limit, db=db) if len(chats) == 0: Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) @@ -950,19 +860,15 @@ async def get_user_chat_list_by_tag_name( ############################ -@router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}', response_model=Optional[ChatResponse]) +async def get_chat_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return ChatResponse(**chat.model_dump()) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) ############################ @@ -970,7 +876,7 @@ async def get_chat_by_id( ############################ -@router.post("/{id}", response_model=Optional[ChatResponse]) +@router.post('/{id}', response_model=Optional[ChatResponse]) async def update_chat_by_id( id: str, form_data: ChatForm, @@ -996,7 +902,7 @@ class MessageForm(BaseModel): content: str -@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse]) +@router.post('/{id}/messages/{message_id}', response_model=Optional[ChatResponse]) async def update_chat_message_by_id( id: str, message_id: str, @@ -1012,7 +918,7 @@ async def update_chat_message_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if chat.user_id != user.id and user.role != "admin": + if chat.user_id != user.id and user.role != 'admin': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -1022,16 +928,16 @@ async def update_chat_message_by_id( id, message_id, { - "content": form_data.content, + 'content': form_data.content, }, db=db, ) event_emitter = get_event_emitter( { - "user_id": user.id, - "chat_id": id, - "message_id": message_id, + 'user_id': user.id, + 'chat_id': id, + 'message_id': message_id, }, False, ) @@ -1039,11 +945,11 @@ async def update_chat_message_by_id( if event_emitter: await event_emitter( { - "type": "chat:message", - "data": { - "chat_id": id, - "message_id": message_id, - "content": form_data.content, + 'type': 'chat:message', + 'data': { + 'chat_id': id, + 'message_id': message_id, + 'content': form_data.content, }, } ) @@ -1059,7 +965,7 @@ class EventForm(BaseModel): data: dict -@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool]) +@router.post('/{id}/messages/{message_id}/event', response_model=Optional[bool]) async def send_chat_message_event_by_id( id: str, message_id: str, @@ -1075,7 +981,7 @@ async def send_chat_message_event_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if chat.user_id != user.id and user.role != "admin": + if chat.user_id != user.id and user.role != 'admin': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -1083,9 +989,9 @@ async def send_chat_message_event_by_id( event_emitter = get_event_emitter( { - "user_id": user.id, - "chat_id": id, - "message_id": message_id, + 'user_id': user.id, + 'chat_id': id, + 'message_id': message_id, } ) @@ -1104,31 +1010,27 @@ async def send_chat_message_event_by_id( ############################ -@router.delete("/{id}", response_model=bool) +@router.delete('/{id}', response_model=bool) async def delete_chat_by_id( request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role == "admin": + if user.role == 'admin': chat = Chats.get_chat_by_id(id, db=db) if not chat: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - Chats.delete_orphan_tags_for_user( - chat.meta.get("tags", []), user.id, threshold=1, db=db - ) + Chats.delete_orphan_tags_for_user(chat.meta.get('tags', []), user.id, threshold=1, db=db) result = Chats.delete_chat_by_id(id, db=db) return result else: - if not has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -1140,9 +1042,7 @@ async def delete_chat_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - Chats.delete_orphan_tags_for_user( - chat.meta.get("tags", []), user.id, threshold=1, db=db - ) + Chats.delete_orphan_tags_for_user(chat.meta.get('tags', []), user.id, threshold=1, db=db) result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db) return result @@ -1153,17 +1053,13 @@ async def delete_chat_by_id( ############################ -@router.get("/{id}/pinned", response_model=Optional[bool]) -async def get_pinned_status_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}/pinned', response_model=Optional[bool]) +async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return chat.pinned else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1171,18 +1067,14 @@ async def get_pinned_status_by_id( ############################ -@router.post("/{id}/pin", response_model=Optional[ChatResponse]) -async def pin_chat_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/{id}/pin', response_model=Optional[ChatResponse]) +async def pin_chat_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: chat = Chats.toggle_chat_pinned_by_id(id, db=db) return chat else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1194,7 +1086,7 @@ class CloneForm(BaseModel): title: Optional[str] = None -@router.post("/{id}/clone", response_model=Optional[ChatResponse]) +@router.post('/{id}/clone', response_model=Optional[ChatResponse]) async def clone_chat_by_id( form_data: CloneForm, id: str, @@ -1205,9 +1097,9 @@ async def clone_chat_by_id( if chat: updated_chat = { **chat.chat, - "originalChatId": chat.id, - "branchPointMessageId": chat.chat["history"]["currentId"], - "title": form_data.title if form_data.title else f"Clone of {chat.title}", + 'originalChatId': chat.id, + 'branchPointMessageId': chat.chat['history']['currentId'], + 'title': form_data.title if form_data.title else f'Clone of {chat.title}', } chats = Chats.import_chats( @@ -1215,10 +1107,10 @@ async def clone_chat_by_id( [ ChatImportForm( **{ - "chat": updated_chat, - "meta": chat.meta, - "pinned": chat.pinned, - "folder_id": chat.folder_id, + 'chat': updated_chat, + 'meta': chat.meta, + 'pinned': chat.pinned, + 'folder_id': chat.folder_id, } ) ], @@ -1234,9 +1126,7 @@ async def clone_chat_by_id( detail=ERROR_MESSAGES.DEFAULT(), ) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1244,12 +1134,9 @@ async def clone_chat_by_id( ############################ -@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse]) -async def clone_shared_chat_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - - if user.role == "admin": +@router.post('/{id}/clone/shared', response_model=Optional[ChatResponse]) +async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin': chat = Chats.get_chat_by_id(id, db=db) else: chat = Chats.get_chat_by_share_id(id, db=db) @@ -1257,9 +1144,9 @@ async def clone_shared_chat_by_id( if chat: updated_chat = { **chat.chat, - "originalChatId": chat.id, - "branchPointMessageId": chat.chat["history"]["currentId"], - "title": f"Clone of {chat.title}", + 'originalChatId': chat.id, + 'branchPointMessageId': chat.chat['history']['currentId'], + 'title': f'Clone of {chat.title}', } chats = Chats.import_chats( @@ -1267,10 +1154,10 @@ async def clone_shared_chat_by_id( [ ChatImportForm( **{ - "chat": updated_chat, - "meta": chat.meta, - "pinned": chat.pinned, - "folder_id": chat.folder_id, + 'chat': updated_chat, + 'meta': chat.meta, + 'pinned': chat.pinned, + 'folder_id': chat.folder_id, } ) ], @@ -1286,9 +1173,7 @@ async def clone_shared_chat_by_id( detail=ERROR_MESSAGES.DEFAULT(), ) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1296,15 +1181,13 @@ async def clone_shared_chat_by_id( ############################ -@router.post("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/{id}/archive', response_model=Optional[ChatResponse]) +async def archive_chat_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: chat = Chats.toggle_chat_archive_by_id(id, db=db) - tag_ids = chat.meta.get("tags", []) + tag_ids = chat.meta.get('tags', []) if chat.archived: # Archived chats are excluded from count — clean up orphans Chats.delete_orphan_tags_for_user(tag_ids, user.id, db=db) @@ -1314,9 +1197,7 @@ async def archive_chat_by_id( return ChatResponse(**chat.model_dump()) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1324,17 +1205,15 @@ async def archive_chat_by_id( ############################ -@router.post("/{id}/share", response_model=Optional[ChatResponse]) +@router.post('/{id}/share', response_model=Optional[ChatResponse]) async def share_chat_by_id( request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if (user.role != "admin") and ( - not has_permission( - user.id, "chat.share", request.app.state.config.USER_PERMISSIONS - ) + if (user.role != 'admin') and ( + not has_permission(user.id, 'chat.share', request.app.state.config.USER_PERMISSIONS) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -1368,10 +1247,8 @@ async def share_chat_by_id( ############################ -@router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/{id}/share', response_model=Optional[bool]) +async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if not chat.share_id: @@ -1397,7 +1274,7 @@ class ChatFolderIdForm(BaseModel): folder_id: Optional[str] = None -@router.post("/{id}/folder", response_model=Optional[ChatResponse]) +@router.post('/{id}/folder', response_model=Optional[ChatResponse]) async def update_chat_folder_id_by_id( id: str, form_data: ChatFolderIdForm, @@ -1406,14 +1283,10 @@ async def update_chat_folder_id_by_id( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - chat = Chats.update_chat_folder_id_by_id_and_user_id( - id, user.id, form_data.folder_id, db=db - ) + chat = Chats.update_chat_folder_id_by_id_and_user_id(id, user.id, form_data.folder_id, db=db) return ChatResponse(**chat.model_dump()) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1421,18 +1294,14 @@ async def update_chat_folder_id_by_id( ############################ -@router.get("/{id}/tags", response_model=list[TagModel]) -async def get_chat_tags_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}/tags', response_model=list[TagModel]) +async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - tags = chat.meta.get("tags", []) + tags = chat.meta.get('tags', []) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) ############################ @@ -1440,7 +1309,7 @@ async def get_chat_tags_by_id( ############################ -@router.post("/{id}/tags", response_model=list[TagModel]) +@router.post('/{id}/tags', response_model=list[TagModel]) async def add_tag_by_id_and_tag_name( id: str, form_data: TagForm, @@ -1449,27 +1318,23 @@ async def add_tag_by_id_and_tag_name( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - tags = chat.meta.get("tags", []) - tag_id = form_data.name.replace(" ", "_").lower() + tags = chat.meta.get('tags', []) + tag_id = form_data.name.replace(' ', '_').lower() - if tag_id == "none": + if tag_id == 'none': raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"), ) if tag_id not in tags: - Chats.add_chat_tag_by_id_and_user_id_and_tag_name( - id, user.id, form_data.name, db=db - ) + Chats.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name, db=db) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) - tags = chat.meta.get("tags", []) + tags = chat.meta.get('tags', []) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -1477,7 +1342,7 @@ async def add_tag_by_id_and_tag_name( ############################ -@router.delete("/{id}/tags", response_model=list[TagModel]) +@router.delete('/{id}/tags', response_model=list[TagModel]) async def delete_tag_by_id_and_tag_name( id: str, form_data: TagForm, @@ -1486,23 +1351,16 @@ async def delete_tag_by_id_and_tag_name( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - Chats.delete_tag_by_id_and_user_id_and_tag_name( - id, user.id, form_data.name, db=db - ) + Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name, db=db) - if ( - Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db) - == 0 - ): + if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db) == 0: Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) - tags = chat.meta.get("tags", []) + tags = chat.meta.get('tags', []) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) ############################ @@ -1510,18 +1368,14 @@ async def delete_tag_by_id_and_tag_name( ############################ -@router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_tags_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/{id}/tags/all', response_model=Optional[bool]) +async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - old_tags = chat.meta.get("tags", []) + old_tags = chat.meta.get('tags', []) Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db) Chats.delete_orphan_tags_for_user(old_tags, user.id, db=db) return True else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 8d145d8cef..e0fb4bb610 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -44,7 +44,7 @@ class ImportConfigForm(BaseModel): config: dict -@router.post("/import", response_model=dict) +@router.post('/import', response_model=dict) async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user)): save_config(form_data.config) return get_config() @@ -55,7 +55,7 @@ async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user ############################ -@router.get("/export", response_model=dict) +@router.get('/export', response_model=dict) async def export_config(user=Depends(get_admin_user)): return get_config() @@ -70,30 +70,26 @@ class ConnectionsConfigForm(BaseModel): ENABLE_BASE_MODELS_CACHE: bool -@router.get("/connections", response_model=ConnectionsConfigForm) +@router.get('/connections', response_model=ConnectionsConfigForm) async def get_connections_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, - "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE, + 'ENABLE_DIRECT_CONNECTIONS': request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + 'ENABLE_BASE_MODELS_CACHE': request.app.state.config.ENABLE_BASE_MODELS_CACHE, } -@router.post("/connections", response_model=ConnectionsConfigForm) +@router.post('/connections', response_model=ConnectionsConfigForm) async def set_connections_config( request: Request, form_data: ConnectionsConfigForm, user=Depends(get_admin_user), ): - request.app.state.config.ENABLE_DIRECT_CONNECTIONS = ( - form_data.ENABLE_DIRECT_CONNECTIONS - ) - request.app.state.config.ENABLE_BASE_MODELS_CACHE = ( - form_data.ENABLE_BASE_MODELS_CACHE - ) + request.app.state.config.ENABLE_DIRECT_CONNECTIONS = form_data.ENABLE_DIRECT_CONNECTIONS + request.app.state.config.ENABLE_BASE_MODELS_CACHE = form_data.ENABLE_BASE_MODELS_CACHE return { - "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, - "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE, + 'ENABLE_DIRECT_CONNECTIONS': request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + 'ENABLE_BASE_MODELS_CACHE': request.app.state.config.ENABLE_BASE_MODELS_CACHE, } @@ -103,7 +99,7 @@ class OAuthClientRegistrationForm(BaseModel): client_name: Optional[str] = None -@router.post("/oauth/clients/register") +@router.post('/oauth/clients/register') async def register_oauth_client( request: Request, form_data: OAuthClientRegistrationForm, @@ -113,24 +109,20 @@ async def register_oauth_client( try: oauth_client_id = form_data.client_id if type: - oauth_client_id = f"{type}:{form_data.client_id}" + oauth_client_id = f'{type}:{form_data.client_id}' - oauth_client_info = ( - await get_oauth_client_info_with_dynamic_client_registration( - request, oauth_client_id, form_data.url - ) + oauth_client_info = await get_oauth_client_info_with_dynamic_client_registration( + request, oauth_client_id, form_data.url ) return { - "status": True, - "oauth_client_info": encrypt_data( - oauth_client_info.model_dump(mode="json") - ), + 'status': True, + 'oauth_client_info': encrypt_data(oauth_client_info.model_dump(mode='json')), } except Exception as e: - log.debug(f"Failed to register OAuth client: {e}") + log.debug(f'Failed to register OAuth client: {e}') raise HTTPException( status_code=400, - detail=f"Failed to register OAuth client", + detail=f'Failed to register OAuth client', ) @@ -142,40 +134,40 @@ async def register_oauth_client( class ToolServerConnection(BaseModel): url: str path: str - type: Optional[str] = "openapi" # openapi, mcp + type: Optional[str] = 'openapi' # openapi, mcp auth_type: Optional[str] headers: Optional[dict | str] = None key: Optional[str] config: Optional[dict] - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class ToolServersConfigForm(BaseModel): TOOL_SERVER_CONNECTIONS: list[ToolServerConnection] -@router.get("/tool_servers", response_model=ToolServersConfigForm) +@router.get('/tool_servers', response_model=ToolServersConfigForm) async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)): return { - "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + 'TOOL_SERVER_CONNECTIONS': request.app.state.config.TOOL_SERVER_CONNECTIONS, } -@router.post("/tool_servers", response_model=ToolServersConfigForm) +@router.post('/tool_servers', response_model=ToolServersConfigForm) async def set_tool_servers_config( request: Request, form_data: ToolServersConfigForm, user=Depends(get_admin_user), ): for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: - server_type = connection.get("type", "openapi") - auth_type = connection.get("auth_type", "none") + server_type = connection.get('type', 'openapi') + auth_type = connection.get('auth_type', 'none') - if auth_type == "oauth_2.1": + if auth_type == 'oauth_2.1': # Remove existing OAuth clients for tool servers - server_id = connection.get("info", {}).get("id") - client_key = f"{server_type}:{server_id}" + server_id = connection.get('info', {}).get('id') + client_key = f'{server_type}:{server_id}' try: request.app.state.oauth_client_manager.remove_client(client_key) @@ -190,42 +182,40 @@ async def set_tool_servers_config( await set_tool_servers(request) for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: - server_type = connection.get("type", "openapi") - if server_type == "mcp": - server_id = connection.get("info", {}).get("id") - auth_type = connection.get("auth_type", "none") + server_type = connection.get('type', 'openapi') + if server_type == 'mcp': + server_id = connection.get('info', {}).get('id') + auth_type = connection.get('auth_type', 'none') - if auth_type == "oauth_2.1" and server_id: + if auth_type == 'oauth_2.1' and server_id: try: - oauth_client_info = connection.get("info", {}).get( - "oauth_client_info", "" - ) + oauth_client_info = connection.get('info', {}).get('oauth_client_info', '') oauth_client_info = decrypt_data(oauth_client_info) request.app.state.oauth_client_manager.add_client( - f"{server_type}:{server_id}", + f'{server_type}:{server_id}', OAuthClientInformationFull(**oauth_client_info), ) except Exception as e: - log.debug(f"Failed to add OAuth client for MCP tool server: {e}") + log.debug(f'Failed to add OAuth client for MCP tool server: {e}') continue return { - "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + 'TOOL_SERVER_CONNECTIONS': request.app.state.config.TOOL_SERVER_CONNECTIONS, } class TerminalServerConnection(BaseModel): - id: Optional[str] = "" - name: Optional[str] = "" + id: Optional[str] = '' + name: Optional[str] = '' enabled: Optional[bool] = True url: str - path: Optional[str] = "/openapi.json" + path: Optional[str] = '/openapi.json' - key: Optional[str] = "" - auth_type: Optional[str] = "bearer" + key: Optional[str] = '' + auth_type: Optional[str] = 'bearer' config: Optional[dict] = None @@ -234,21 +224,21 @@ class TerminalServerConnection(BaseModel): policy_id: Optional[str] = None policy: Optional[dict] = None # cached policy data - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class TerminalServersConfigForm(BaseModel): TERMINAL_SERVER_CONNECTIONS: list[TerminalServerConnection] -@router.get("/terminal_servers") +@router.get('/terminal_servers') async def get_terminal_servers_config(request: Request, user=Depends(get_admin_user)): return { - "TERMINAL_SERVER_CONNECTIONS": request.app.state.config.TERMINAL_SERVER_CONNECTIONS, + 'TERMINAL_SERVER_CONNECTIONS': request.app.state.config.TERMINAL_SERVER_CONNECTIONS, } -@router.post("/terminal_servers") +@router.post('/terminal_servers') async def set_terminal_servers_config( request: Request, form_data: TerminalServersConfigForm, @@ -261,57 +251,45 @@ async def set_terminal_servers_config( await set_terminal_servers(request) return { - "TERMINAL_SERVER_CONNECTIONS": request.app.state.config.TERMINAL_SERVER_CONNECTIONS, + 'TERMINAL_SERVER_CONNECTIONS': request.app.state.config.TERMINAL_SERVER_CONNECTIONS, } -@router.post("/tool_servers/verify") -async def verify_tool_servers_config( - request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user) -): +@router.post('/tool_servers/verify') +async def verify_tool_servers_config(request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)): """ Verify the connection to the tool server. """ try: - if form_data.type == "mcp": - if form_data.auth_type == "oauth_2.1": + if form_data.type == 'mcp': + if form_data.auth_type == 'oauth_2.1': discovery_urls = await get_discovery_urls(form_data.url) for discovery_url in discovery_urls: - log.debug( - f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" - ) + log.debug(f'Trying to fetch OAuth 2.1 discovery document from {discovery_url}') async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), ) as session: - async with session.get( - discovery_url - ) as oauth_server_metadata_response: + async with session.get(discovery_url) as oauth_server_metadata_response: if oauth_server_metadata_response.status == 200: try: - oauth_server_metadata = ( - OAuthMetadata.model_validate( - await oauth_server_metadata_response.json() - ) + oauth_server_metadata = OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() ) return { - "status": True, - "oauth_server_metadata": oauth_server_metadata.model_dump( - mode="json" - ), + 'status': True, + 'oauth_server_metadata': oauth_server_metadata.model_dump(mode='json'), } except Exception as e: - log.info( - f"Failed to parse OAuth 2.1 discovery document: {e}" - ) + log.info(f'Failed to parse OAuth 2.1 discovery document: {e}') raise HTTPException( status_code=400, - detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_url}", + detail=f'Failed to parse OAuth 2.1 discovery document from {discovery_url}', ) raise HTTPException( status_code=400, - detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls}", + detail=f'Failed to fetch OAuth 2.1 discovery document from {discovery_urls}', ) else: try: @@ -319,25 +297,25 @@ async def verify_tool_servers_config( headers = None token = None - if form_data.auth_type == "bearer": + if form_data.auth_type == 'bearer': token = form_data.key - elif form_data.auth_type == "session": + elif form_data.auth_type == 'session': token = request.state.token.credentials - elif form_data.auth_type == "system_oauth": + elif form_data.auth_type == 'system_oauth': oauth_token = None try: - if request.cookies.get("oauth_session_id", None): + if request.cookies.get('oauth_session_id', None): oauth_token = await request.app.state.oauth_manager.get_oauth_token( user.id, - request.cookies.get("oauth_session_id", None), + request.cookies.get('oauth_session_id', None), ) if oauth_token: - token = oauth_token.get("access_token", "") + token = oauth_token.get('access_token', '') except Exception as e: pass if token: - headers = {"Authorization": f"Bearer {token}"} + headers = {'Authorization': f'Bearer {token}'} if form_data.headers and isinstance(form_data.headers, dict): if headers is None: @@ -347,14 +325,14 @@ async def verify_tool_servers_config( await client.connect(form_data.url, headers=headers) specs = await client.list_tool_specs() return { - "status": True, - "specs": specs, + 'status': True, + 'specs': specs, } except Exception as e: - log.debug(f"Failed to create MCP client: {e}") + log.debug(f'Failed to create MCP client: {e}') raise HTTPException( status_code=400, - detail=f"Failed to create MCP client", + detail=f'Failed to create MCP client', ) finally: if client: @@ -362,28 +340,26 @@ async def verify_tool_servers_config( else: # openapi token = None headers = None - if form_data.auth_type == "bearer": + if form_data.auth_type == 'bearer': token = form_data.key - elif form_data.auth_type == "session": + elif form_data.auth_type == 'session': token = request.state.token.credentials - elif form_data.auth_type == "system_oauth": + elif form_data.auth_type == 'system_oauth': try: - if request.cookies.get("oauth_session_id", None): - oauth_token = ( - await request.app.state.oauth_manager.get_oauth_token( - user.id, - request.cookies.get("oauth_session_id", None), - ) + if request.cookies.get('oauth_session_id', None): + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get('oauth_session_id', None), ) if oauth_token: - token = oauth_token.get("access_token", "") + token = oauth_token.get('access_token', '') except Exception as e: pass if token: - headers = {"Authorization": f"Bearer {token}"} + headers = {'Authorization': f'Bearer {token}'} if form_data.headers and isinstance(form_data.headers, dict): if headers is None: @@ -395,10 +371,10 @@ async def verify_tool_servers_config( except HTTPException as e: raise e except Exception as e: - log.debug(f"Failed to connect to the tool server: {e}") + log.debug(f'Failed to connect to the tool server: {e}') raise HTTPException( status_code=400, - detail=f"Failed to connect to the tool server", + detail=f'Failed to connect to the tool server', ) @@ -423,91 +399,68 @@ class CodeInterpreterConfigForm(BaseModel): CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int] -@router.get("/code_execution", response_model=CodeInterpreterConfigForm) +@router.get('/code_execution', response_model=CodeInterpreterConfigForm) async def get_code_execution_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION, - "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE, - "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL, - "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, - "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, - "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, - "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, - "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, - "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, - "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, - "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, - "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, - "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, - "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, - "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, + 'ENABLE_CODE_EXECUTION': request.app.state.config.ENABLE_CODE_EXECUTION, + 'CODE_EXECUTION_ENGINE': request.app.state.config.CODE_EXECUTION_ENGINE, + 'CODE_EXECUTION_JUPYTER_URL': request.app.state.config.CODE_EXECUTION_JUPYTER_URL, + 'CODE_EXECUTION_JUPYTER_AUTH': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, + 'CODE_EXECUTION_JUPYTER_AUTH_TOKEN': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + 'CODE_EXECUTION_JUPYTER_AUTH_PASSWORD': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + 'CODE_EXECUTION_JUPYTER_TIMEOUT': request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, + 'ENABLE_CODE_INTERPRETER': request.app.state.config.ENABLE_CODE_INTERPRETER, + 'CODE_INTERPRETER_ENGINE': request.app.state.config.CODE_INTERPRETER_ENGINE, + 'CODE_INTERPRETER_PROMPT_TEMPLATE': request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, + 'CODE_INTERPRETER_JUPYTER_URL': request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + 'CODE_INTERPRETER_JUPYTER_AUTH': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + 'CODE_INTERPRETER_JUPYTER_AUTH_TOKEN': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + 'CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + 'CODE_INTERPRETER_JUPYTER_TIMEOUT': request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } -@router.post("/code_execution", response_model=CodeInterpreterConfigForm) +@router.post('/code_execution', response_model=CodeInterpreterConfigForm) async def set_code_execution_config( request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user) ): - request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE - request.app.state.config.CODE_EXECUTION_JUPYTER_URL = ( - form_data.CODE_EXECUTION_JUPYTER_URL - ) - request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = ( - form_data.CODE_EXECUTION_JUPYTER_AUTH - ) - request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = ( - form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN - ) - request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( - form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD - ) - request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = ( - form_data.CODE_EXECUTION_JUPYTER_TIMEOUT - ) + request.app.state.config.CODE_EXECUTION_JUPYTER_URL = form_data.CODE_EXECUTION_JUPYTER_URL + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = form_data.CODE_EXECUTION_JUPYTER_AUTH + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD + request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = form_data.CODE_EXECUTION_JUPYTER_TIMEOUT request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE - request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = ( - form_data.CODE_INTERPRETER_PROMPT_TEMPLATE - ) + request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = form_data.CODE_INTERPRETER_PROMPT_TEMPLATE - request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = ( - form_data.CODE_INTERPRETER_JUPYTER_URL - ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = form_data.CODE_INTERPRETER_JUPYTER_URL - request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = ( - form_data.CODE_INTERPRETER_JUPYTER_AUTH - ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = form_data.CODE_INTERPRETER_JUPYTER_AUTH - request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( - form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN - ) - request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( - form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD - ) - request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = ( - form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT - ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT return { - "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION, - "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE, - "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL, - "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, - "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, - "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, - "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, - "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, - "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, - "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, - "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, - "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, - "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, - "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, - "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, + 'ENABLE_CODE_EXECUTION': request.app.state.config.ENABLE_CODE_EXECUTION, + 'CODE_EXECUTION_ENGINE': request.app.state.config.CODE_EXECUTION_ENGINE, + 'CODE_EXECUTION_JUPYTER_URL': request.app.state.config.CODE_EXECUTION_JUPYTER_URL, + 'CODE_EXECUTION_JUPYTER_AUTH': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, + 'CODE_EXECUTION_JUPYTER_AUTH_TOKEN': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + 'CODE_EXECUTION_JUPYTER_AUTH_PASSWORD': request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + 'CODE_EXECUTION_JUPYTER_TIMEOUT': request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, + 'ENABLE_CODE_INTERPRETER': request.app.state.config.ENABLE_CODE_INTERPRETER, + 'CODE_INTERPRETER_ENGINE': request.app.state.config.CODE_INTERPRETER_ENGINE, + 'CODE_INTERPRETER_PROMPT_TEMPLATE': request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, + 'CODE_INTERPRETER_JUPYTER_URL': request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + 'CODE_INTERPRETER_JUPYTER_AUTH': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + 'CODE_INTERPRETER_JUPYTER_AUTH_TOKEN': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + 'CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD': request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + 'CODE_INTERPRETER_JUPYTER_TIMEOUT': request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } @@ -522,32 +475,30 @@ class ModelsConfigForm(BaseModel): DEFAULT_MODEL_PARAMS: Optional[dict] = None -@router.get("/models", response_model=ModelsConfigForm) +@router.get('/models', response_model=ModelsConfigForm) async def get_models_config(request: Request, user=Depends(get_admin_user)): return { - "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, - "DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS, - "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, - "DEFAULT_MODEL_METADATA": request.app.state.config.DEFAULT_MODEL_METADATA, - "DEFAULT_MODEL_PARAMS": request.app.state.config.DEFAULT_MODEL_PARAMS, + 'DEFAULT_MODELS': request.app.state.config.DEFAULT_MODELS, + 'DEFAULT_PINNED_MODELS': request.app.state.config.DEFAULT_PINNED_MODELS, + 'MODEL_ORDER_LIST': request.app.state.config.MODEL_ORDER_LIST, + 'DEFAULT_MODEL_METADATA': request.app.state.config.DEFAULT_MODEL_METADATA, + 'DEFAULT_MODEL_PARAMS': request.app.state.config.DEFAULT_MODEL_PARAMS, } -@router.post("/models", response_model=ModelsConfigForm) -async def set_models_config( - request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user) -): +@router.post('/models', response_model=ModelsConfigForm) +async def set_models_config(request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)): request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS request.app.state.config.DEFAULT_PINNED_MODELS = form_data.DEFAULT_PINNED_MODELS request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST request.app.state.config.DEFAULT_MODEL_METADATA = form_data.DEFAULT_MODEL_METADATA request.app.state.config.DEFAULT_MODEL_PARAMS = form_data.DEFAULT_MODEL_PARAMS return { - "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, - "DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS, - "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, - "DEFAULT_MODEL_METADATA": request.app.state.config.DEFAULT_MODEL_METADATA, - "DEFAULT_MODEL_PARAMS": request.app.state.config.DEFAULT_MODEL_PARAMS, + 'DEFAULT_MODELS': request.app.state.config.DEFAULT_MODELS, + 'DEFAULT_PINNED_MODELS': request.app.state.config.DEFAULT_PINNED_MODELS, + 'MODEL_ORDER_LIST': request.app.state.config.MODEL_ORDER_LIST, + 'DEFAULT_MODEL_METADATA': request.app.state.config.DEFAULT_MODEL_METADATA, + 'DEFAULT_MODEL_PARAMS': request.app.state.config.DEFAULT_MODEL_PARAMS, } @@ -560,14 +511,14 @@ class SetDefaultSuggestionsForm(BaseModel): suggestions: list[PromptSuggestion] -@router.post("/suggestions", response_model=list[PromptSuggestion]) +@router.post('/suggestions', response_model=list[PromptSuggestion]) async def set_default_suggestions( request: Request, form_data: SetDefaultSuggestionsForm, user=Depends(get_admin_user), ): data = form_data.model_dump() - request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] + request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data['suggestions'] return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS @@ -580,18 +531,18 @@ class SetBannersForm(BaseModel): banners: list[BannerModel] -@router.post("/banners", response_model=list[BannerModel]) +@router.post('/banners', response_model=list[BannerModel]) async def set_banners( request: Request, form_data: SetBannersForm, user=Depends(get_admin_user), ): data = form_data.model_dump() - request.app.state.config.BANNERS = data["banners"] + request.app.state.config.BANNERS = data['banners'] return request.app.state.config.BANNERS -@router.get("/banners", response_model=list[BannerModel]) +@router.get('/banners', response_model=list[BannerModel]) async def get_banners( request: Request, user=Depends(get_verified_user), diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index 22bb20df9b..f301613286 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -51,9 +51,7 @@ router = APIRouter() import os -EMBEDDING_MODEL_NAME = os.environ.get( - "AUXILIARY_EMBEDDING_MODEL", "TaylorAI/bge-micro-v2" -) +EMBEDDING_MODEL_NAME = os.environ.get('AUXILIARY_EMBEDDING_MODEL', 'TaylorAI/bge-micro-v2') _embedding_model = None @@ -65,13 +63,11 @@ def _get_embedding_model(): _embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME) except Exception as e: - log.error(f"Embedding model load failed: {e}") + log.error(f'Embedding model load failed: {e}') return _embedding_model -def _calculate_elo( - feedbacks: list[LeaderboardFeedbackData], similarities: dict = None -) -> dict: +def _calculate_elo(feedbacks: list[LeaderboardFeedbackData], similarities: dict = None) -> dict: """ Calculate Elo ratings for models based on user feedback. @@ -90,35 +86,33 @@ def _calculate_elo( def get_or_create_stats(model_id): if model_id not in model_stats: - model_stats[model_id] = {"rating": 1000.0, "won": 0, "lost": 0} + model_stats[model_id] = {'rating': 1000.0, 'won': 0, 'lost': 0} return model_stats[model_id] for feedback in feedbacks: data = feedback.data or {} - winner_id = data.get("model_id") - rating_value = str(data.get("rating", "")) - if not winner_id or rating_value not in ("1", "-1"): + winner_id = data.get('model_id') + rating_value = str(data.get('rating', '')) + if not winner_id or rating_value not in ('1', '-1'): continue - won = rating_value == "1" + won = rating_value == '1' weight = similarities.get(feedback.id, 1.0) if similarities else 1.0 - for opponent_id in data.get("sibling_model_ids") or []: + for opponent_id in data.get('sibling_model_ids') or []: winner = get_or_create_stats(winner_id) opponent = get_or_create_stats(opponent_id) - expected = 1 / (1 + 10 ** ((opponent["rating"] - winner["rating"]) / 400)) + expected = 1 / (1 + 10 ** ((opponent['rating'] - winner['rating']) / 400)) - winner["rating"] += K_FACTOR * ((1 if won else 0) - expected) * weight - opponent["rating"] += ( - K_FACTOR * ((0 if won else 1) - (1 - expected)) * weight - ) + winner['rating'] += K_FACTOR * ((1 if won else 0) - expected) * weight + opponent['rating'] += K_FACTOR * ((0 if won else 1) - (1 - expected)) * weight if won: - winner["won"] += 1 - opponent["lost"] += 1 + winner['won'] += 1 + opponent['lost'] += 1 else: - winner["lost"] += 1 - opponent["won"] += 1 + winner['lost'] += 1 + opponent['won'] += 1 return model_stats @@ -139,16 +133,13 @@ def _get_top_tags(feedbacks: list[LeaderboardFeedbackData], limit: int = 5) -> d for feedback in feedbacks: data = feedback.data or {} - model_id = data.get("model_id") + model_id = data.get('model_id') if model_id: - for tag in data.get("tags", []): + for tag in data.get('tags', []): tag_counts[model_id][tag] += 1 return { - model_id: [ - {"tag": tag, "count": count} - for tag, count in sorted(tags.items(), key=lambda x: -x[1])[:limit] - ] + model_id: [{'tag': tag, 'count': count} for tag, count in sorted(tags.items(), key=lambda x: -x[1])[:limit]] for model_id, tags in tag_counts.items() } @@ -172,14 +163,7 @@ def _compute_similarities(feedbacks: list[LeaderboardFeedbackData], query: str) if not embedding_model: return {} - all_tags = list( - { - tag - for feedback in feedbacks - if feedback.data - for tag in feedback.data.get("tags", []) - } - ) + all_tags = list({tag for feedback in feedbacks if feedback.data for tag in feedback.data.get('tags', [])}) if not all_tags: return {} @@ -187,23 +171,18 @@ def _compute_similarities(feedbacks: list[LeaderboardFeedbackData], query: str) tag_embeddings = embedding_model.encode(all_tags) query_embedding = embedding_model.encode([query])[0] except Exception as e: - log.error(f"Embedding error: {e}") + log.error(f'Embedding error: {e}') return {} # Vectorized cosine similarity tag_norms = np.linalg.norm(tag_embeddings, axis=1) query_norm = np.linalg.norm(query_embedding) - similarities = np.dot(tag_embeddings, query_embedding) / ( - tag_norms * query_norm + 1e-9 - ) + similarities = np.dot(tag_embeddings, query_embedding) / (tag_norms * query_norm + 1e-9) tag_similarity_map = dict(zip(all_tags, similarities.tolist())) return { feedback.id: max( - ( - tag_similarity_map.get(tag, 0) - for tag in (feedback.data or {}).get("tags", []) - ), + (tag_similarity_map.get(tag, 0) for tag in (feedback.data or {}).get('tags', [])), default=0, ) for feedback in feedbacks @@ -223,7 +202,7 @@ class LeaderboardResponse(BaseModel): entries: list[LeaderboardEntry] -@router.get("/leaderboard", response_model=LeaderboardResponse) +@router.get('/leaderboard', response_model=LeaderboardResponse) async def get_leaderboard( query: Optional[str] = None, user=Depends(get_admin_user), @@ -234,9 +213,7 @@ async def get_leaderboard( similarities = None if query and query.strip(): - similarities = await run_in_threadpool( - _compute_similarities, feedbacks, query.strip() - ) + similarities = await run_in_threadpool(_compute_similarities, feedbacks, query.strip()) elo_stats = _calculate_elo(feedbacks, similarities) tags_by_model = _get_top_tags(feedbacks) @@ -245,10 +222,10 @@ async def get_leaderboard( [ LeaderboardEntry( model_id=mid, - rating=round(s["rating"]), - won=s["won"], - lost=s["lost"], - count=s["won"] + s["lost"], + rating=round(s['rating']), + won=s['won'], + lost=s['lost'], + count=s['won'] + s['lost'], top_tags=tags_by_model.get(mid, []), ) for mid, s in elo_stats.items() @@ -260,7 +237,7 @@ async def get_leaderboard( return LeaderboardResponse(entries=entries) -@router.get("/leaderboard/{model_id}/history", response_model=ModelHistoryResponse) +@router.get('/leaderboard/{model_id}/history', response_model=ModelHistoryResponse) async def get_model_history( model_id: str, days: int = 30, @@ -268,9 +245,7 @@ async def get_model_history( db: Session = Depends(get_session), ): """Get daily win/loss history for a specific model.""" - history = Feedbacks.get_model_evaluation_history( - model_id=model_id, days=days, db=db - ) + history = Feedbacks.get_model_evaluation_history(model_id=model_id, days=days, db=db) return ModelHistoryResponse(model_id=model_id, history=history) @@ -279,11 +254,11 @@ async def get_model_history( ############################ -@router.get("/config") +@router.get('/config') async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_EVALUATION_ARENA_MODELS": request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS, - "EVALUATION_ARENA_MODELS": request.app.state.config.EVALUATION_ARENA_MODELS, + 'ENABLE_EVALUATION_ARENA_MODELS': request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS, + 'EVALUATION_ARENA_MODELS': request.app.state.config.EVALUATION_ARENA_MODELS, } @@ -297,7 +272,7 @@ class UpdateConfigForm(BaseModel): EVALUATION_ARENA_MODELS: Optional[list[dict]] = None -@router.post("/config") +@router.post('/config') async def update_config( request: Request, form_data: UpdateConfigForm, @@ -309,54 +284,42 @@ async def update_config( if form_data.EVALUATION_ARENA_MODELS is not None: config.EVALUATION_ARENA_MODELS = form_data.EVALUATION_ARENA_MODELS return { - "ENABLE_EVALUATION_ARENA_MODELS": config.ENABLE_EVALUATION_ARENA_MODELS, - "EVALUATION_ARENA_MODELS": config.EVALUATION_ARENA_MODELS, + 'ENABLE_EVALUATION_ARENA_MODELS': config.ENABLE_EVALUATION_ARENA_MODELS, + 'EVALUATION_ARENA_MODELS': config.EVALUATION_ARENA_MODELS, } -@router.get("/feedbacks/all", response_model=list[FeedbackResponse]) -async def get_all_feedbacks( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/feedbacks/all', response_model=list[FeedbackResponse]) +async def get_all_feedbacks(user=Depends(get_admin_user), db: Session = Depends(get_session)): feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks -@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse]) -async def get_all_feedback_ids( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/feedbacks/all/ids', response_model=list[FeedbackIdResponse]) +async def get_all_feedback_ids(user=Depends(get_admin_user), db: Session = Depends(get_session)): return Feedbacks.get_all_feedback_ids(db=db) -@router.delete("/feedbacks/all") -async def delete_all_feedbacks( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.delete('/feedbacks/all') +async def delete_all_feedbacks(user=Depends(get_admin_user), db: Session = Depends(get_session)): success = Feedbacks.delete_all_feedbacks(db=db) return success -@router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) -async def export_all_feedbacks( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/feedbacks/all/export', response_model=list[FeedbackModel]) +async def export_all_feedbacks(user=Depends(get_admin_user), db: Session = Depends(get_session)): feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks -@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) -async def get_feedbacks( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/feedbacks/user', response_model=list[FeedbackUserResponse]) +async def get_feedbacks(user=Depends(get_verified_user), db: Session = Depends(get_session)): feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db) return feedbacks -@router.delete("/feedbacks", response_model=bool) -async def delete_feedbacks( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/feedbacks', response_model=bool) +async def delete_feedbacks(user=Depends(get_verified_user), db: Session = Depends(get_session)): success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db) return success @@ -364,7 +327,7 @@ async def delete_feedbacks( PAGE_ITEM_COUNT = 30 -@router.get("/feedbacks/list", response_model=FeedbackListResponse) +@router.get('/feedbacks/list', response_model=FeedbackListResponse) async def get_feedbacks( order_by: Optional[str] = None, direction: Optional[str] = None, @@ -379,24 +342,22 @@ async def get_feedbacks( filter = {} if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db) return result -@router.post("/feedback", response_model=FeedbackModel) +@router.post('/feedback', response_model=FeedbackModel) async def create_feedback( request: Request, form_data: FeedbackForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - feedback = Feedbacks.insert_new_feedback( - user_id=user.id, form_data=form_data, db=db - ) + feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data, db=db) if not feedback: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -406,61 +367,45 @@ async def create_feedback( return feedback -@router.get("/feedback/{id}", response_model=FeedbackModel) -async def get_feedback_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin": +@router.get('/feedback/{id}', response_model=FeedbackModel) +async def get_feedback_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin': feedback = Feedbacks.get_feedback_by_id(id=id, db=db) else: - feedback = Feedbacks.get_feedback_by_id_and_user_id( - id=id, user_id=user.id, db=db - ) + feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id, db=db) if not feedback: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) return feedback -@router.post("/feedback/{id}", response_model=FeedbackModel) +@router.post('/feedback/{id}', response_model=FeedbackModel) async def update_feedback_by_id( id: str, form_data: FeedbackForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role == "admin": + if user.role == 'admin': feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db) else: - feedback = Feedbacks.update_feedback_by_id_and_user_id( - id=id, user_id=user.id, form_data=form_data, db=db - ) + feedback = Feedbacks.update_feedback_by_id_and_user_id(id=id, user_id=user.id, form_data=form_data, db=db) if not feedback: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) return feedback -@router.delete("/feedback/{id}") -async def delete_feedback_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin": +@router.delete('/feedback/{id}') +async def delete_feedback_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin': success = Feedbacks.delete_feedback_by_id(id=id, db=db) else: - success = Feedbacks.delete_feedback_by_id_and_user_id( - id=id, user_id=user.id, db=db - ) + success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id, db=db) if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) return success diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 0ec16b6f50..b6d081085f 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -73,14 +73,14 @@ def _is_text_file(file_path: str, chunk_size: int = 8192) -> bool: """ try: resolved = Storage.get_file(file_path) - with open(resolved, "rb") as f: + with open(resolved, 'rb') as f: chunk = f.read(chunk_size) if not chunk: return False # Null bytes are a strong indicator of binary content - if b"\x00" in chunk: + if b'\x00' in chunk: return False - chunk.decode("utf-8") + chunk.decode('utf-8') return True except (UnicodeDecodeError, Exception): return False @@ -100,31 +100,25 @@ def process_uploaded_file( content_type = file.content_type # Detect mis-labeled text files (e.g. .ts → video/mp2t) - if content_type and content_type.startswith(("image/", "video/")): + if content_type and content_type.startswith(('image/', 'video/')): if _is_text_file(file_path): - content_type = "text/plain" + content_type = 'text/plain' if content_type: - stt_supported_content_types = getattr( - request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) + stt_supported_content_types = getattr(request.app.state.config, 'STT_SUPPORTED_CONTENT_TYPES', []) if strict_match_mime_type(stt_supported_content_types, content_type): file_path_processed = Storage.get_file(file_path) - result = transcribe( - request, file_path_processed, file_metadata, user - ) + result = transcribe(request, file_path_processed, file_metadata, user) process_file( request, - ProcessFileForm( - file_id=file_item.id, content=result.get("text", "") - ), + ProcessFileForm(file_id=file_item.id, content=result.get('text', '')), user=user, db=db_session, ) - elif (not content_type.startswith(("image/", "video/"))) or ( - request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" + elif (not content_type.startswith(('image/', 'video/'))) or ( + request.app.state.config.CONTENT_EXTRACTION_ENGINE == 'external' ): process_file( request, @@ -133,13 +127,9 @@ def process_uploaded_file( db=db_session, ) else: - raise Exception( - f"File type {content_type} is not supported for processing" - ) + raise Exception(f'File type {content_type} is not supported for processing') else: - log.info( - f"File type {file.content_type} is not provided, but trying to process anyway" - ) + log.info(f'File type {file.content_type} is not provided, but trying to process anyway') process_file( request, ProcessFileForm(file_id=file_item.id), @@ -148,12 +138,12 @@ def process_uploaded_file( ) except Exception as e: - log.error(f"Error processing file: {file_item.id}") + log.error(f'Error processing file: {file_item.id}') Files.update_file_data_by_id( file_item.id, { - "status": "failed", - "error": str(e.detail) if hasattr(e, "detail") else str(e), + 'status': 'failed', + 'error': str(e.detail) if hasattr(e, 'detail') else str(e), }, db=db_session, ) @@ -165,7 +155,7 @@ def process_uploaded_file( _process_handler(db_session) -@router.post("/", response_model=FileModelResponse) +@router.post('/', response_model=FileModelResponse) def upload_file( request: Request, background_tasks: BackgroundTasks, @@ -198,7 +188,7 @@ def upload_file_handler( background_tasks: Optional[BackgroundTasks] = None, db: Optional[Session] = None, ): - log.info(f"file.content_type: {file.content_type} {process}") + log.info(f'file.content_type: {file.content_type} {process}') if isinstance(metadata, str): try: @@ -206,7 +196,7 @@ def upload_file_handler( except json.JSONDecodeError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"), + detail=ERROR_MESSAGES.DEFAULT('Invalid metadata format'), ) file_metadata = metadata if metadata else {} @@ -216,7 +206,7 @@ def upload_file_handler( file_extension = os.path.splitext(filename)[1] # Remove the leading dot from the file extension - file_extension = file_extension[1:] if file_extension else "" + file_extension = file_extension[1:] if file_extension else '' if process and request.app.state.config.ALLOWED_FILE_EXTENSIONS: request.app.state.config.ALLOWED_FILE_EXTENSIONS = [ @@ -226,23 +216,21 @@ def upload_file_handler( if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT( - f"File type {file_extension} is not allowed" - ), + detail=ERROR_MESSAGES.DEFAULT(f'File type {file_extension} is not allowed'), ) # replace filename with uuid id = str(uuid.uuid4()) name = filename - filename = f"{id}_{filename}" + filename = f'{id}_{filename}' contents, file_path = Storage.upload_file( file.file, filename, { - "OpenWebUI-User-Email": user.email, - "OpenWebUI-User-Id": user.id, - "OpenWebUI-User-Name": user.name, - "OpenWebUI-File-Id": id, + 'OpenWebUI-User-Email': user.email, + 'OpenWebUI-User-Id': user.id, + 'OpenWebUI-User-Name': user.name, + 'OpenWebUI-File-Id': id, }, ) @@ -250,35 +238,27 @@ def upload_file_handler( user.id, FileForm( **{ - "id": id, - "filename": name, - "path": file_path, - "data": { - **({"status": "pending"} if process else {}), + 'id': id, + 'filename': name, + 'path': file_path, + 'data': { + **({'status': 'pending'} if process else {}), }, - "meta": { - "name": name, - "content_type": ( - file.content_type - if isinstance(file.content_type, str) - else None - ), - "size": len(contents), - "data": file_metadata, + 'meta': { + 'name': name, + 'content_type': (file.content_type if isinstance(file.content_type, str) else None), + 'size': len(contents), + 'data': file_metadata, }, } ), db=db, ) - if "channel_id" in file_metadata: - channel = Channels.get_channel_by_id_and_user_id( - file_metadata["channel_id"], user.id, db=db - ) + if 'channel_id' in file_metadata: + channel = Channels.get_channel_by_id_and_user_id(file_metadata['channel_id'], user.id, db=db) if channel: - Channels.add_file_to_channel_by_id( - channel.id, file_item.id, user.id, db=db - ) + Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id, db=db) if process: if background_tasks and process_in_background: @@ -291,7 +271,7 @@ def upload_file_handler( file_metadata, user, ) - return {"status": True, **file_item.model_dump()} + return {'status': True, **file_item.model_dump()} else: process_uploaded_file( request, @@ -302,14 +282,14 @@ def upload_file_handler( user, db=db, ) - return {"status": True, **file_item.model_dump()} + return {'status': True, **file_item.model_dump()} else: if file_item: return file_item else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), + detail=ERROR_MESSAGES.DEFAULT('Error uploading file'), ) except HTTPException as e: @@ -318,7 +298,7 @@ def upload_file_handler( log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), + detail=ERROR_MESSAGES.DEFAULT('Error uploading file'), ) @@ -330,26 +310,22 @@ def upload_file_handler( PAGE_SIZE = 50 -@router.get("/", response_model=FileListResponse) +@router.get('/', response_model=FileListResponse) async def list_files( user=Depends(get_verified_user), - page: int = Query(1, ge=1, description="Page number (1-indexed)"), + page: int = Query(1, ge=1, description='Page number (1-indexed)'), content: bool = Query(True), db: Session = Depends(get_session), ): skip = (page - 1) * PAGE_SIZE - user_id = ( - None if (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) else user.id - ) + user_id = None if (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) else user.id - result = Files.get_file_list( - user_id=user_id, skip=skip, limit=PAGE_SIZE, db=db - ) + result = Files.get_file_list(user_id=user_id, skip=skip, limit=PAGE_SIZE, db=db) if not content: for file in result.items: - if file.data and "content" in file.data: - del file.data["content"] + if file.data and 'content' in file.data: + del file.data['content'] return result @@ -359,17 +335,15 @@ async def list_files( ############################ -@router.get("/search", response_model=list[FileModelResponse]) +@router.get('/search', response_model=list[FileModelResponse]) async def search_files( filename: str = Query( ..., description="Filename pattern to search for. Supports wildcards such as '*.txt'", ), content: bool = Query(True), - skip: int = Query(0, ge=0, description="Number of files to skip"), - limit: int = Query( - 100, ge=1, le=1000, description="Maximum number of files to return" - ), + skip: int = Query(0, ge=0, description='Number of files to skip'), + limit: int = Query(100, ge=1, le=1000, description='Maximum number of files to return'), user=Depends(get_verified_user), db: Session = Depends(get_session), ): @@ -378,9 +352,7 @@ async def search_files( Uses SQL-based filtering with pagination for better performance. """ # Determine user_id: null for admin with bypass (search all), user.id otherwise - user_id = ( - None if (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) else user.id - ) + user_id = None if (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) else user.id # Use optimized database query with pagination files = Files.search_files( @@ -394,13 +366,13 @@ async def search_files( if not files: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="No files found matching the pattern.", + detail='No files found matching the pattern.', ) if not content: for file in files: - if file.data and "content" in file.data: - del file.data["content"] + if file.data and 'content' in file.data: + del file.data['content'] return files @@ -410,10 +382,8 @@ async def search_files( ############################ -@router.delete("/all") -async def delete_all_files( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.delete('/all') +async def delete_all_files(user=Depends(get_admin_user), db: Session = Depends(get_session)): result = Files.delete_all_files(db=db) if result: try: @@ -421,16 +391,16 @@ async def delete_all_files( VECTOR_DB_CLIENT.reset() except Exception as e: log.exception(e) - log.error("Error deleting files") + log.error('Error deleting files') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting files'), ) - return {"message": "All files deleted successfully"} + return {'message': 'All files deleted successfully'} else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting files'), ) @@ -439,10 +409,8 @@ async def delete_all_files( ############################ -@router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}', response_model=Optional[FileModel]) +async def get_file_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): file = Files.get_file_by_id(id, db=db) if not file: @@ -451,11 +419,7 @@ async def get_file_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): return file else: raise HTTPException( @@ -464,7 +428,7 @@ async def get_file_by_id( ) -@router.get("/{id}/process/status") +@router.get('/{id}/process/status') async def get_file_process_status( id: str, stream: bool = Query(False), @@ -479,11 +443,7 @@ async def get_file_process_status( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): if stream: MAX_FILE_PROCESSING_DURATION = 3600 * 2 @@ -494,32 +454,32 @@ async def get_file_process_status( for _ in range(MAX_FILE_PROCESSING_DURATION): file_item = Files.get_file_by_id(file_id) # Creates own session if file_item: - data = file_item.model_dump().get("data", {}) - status = data.get("status") + data = file_item.model_dump().get('data', {}) + status = data.get('status') if status: - event = {"status": status} - if status == "failed": - event["error"] = data.get("error") + event = {'status': status} + if status == 'failed': + event['error'] = data.get('error') - yield f"data: {json.dumps(event)}\n\n" - if status in ("completed", "failed"): + yield f'data: {json.dumps(event)}\n\n' + if status in ('completed', 'failed'): break else: # Legacy break else: - yield f"data: {json.dumps({'status': 'not_found'})}\n\n" + yield f'data: {json.dumps({"status": "not_found"})}\n\n' break await asyncio.sleep(1) return StreamingResponse( event_stream(file.id), - media_type="text/event-stream", + media_type='text/event-stream', ) else: - return {"status": file.data.get("status", "pending")} + return {'status': file.data.get('status', 'pending')} else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -532,10 +492,8 @@ async def get_file_process_status( ############################ -@router.get("/{id}/data/content") -async def get_file_data_content_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}/data/content') +async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): file = Files.get_file_by_id(id, db=db) if not file: @@ -544,12 +502,8 @@ async def get_file_data_content_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): - return {"content": file.data.get("content", "")} + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): + return {'content': file.data.get('content', '')} else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -566,7 +520,7 @@ class ContentForm(BaseModel): content: str -@router.post("/{id}/data/content/update") +@router.post('/{id}/data/content/update') def update_file_data_content_by_id( request: Request, id: str, @@ -582,11 +536,7 @@ def update_file_data_content_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "write", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'write', user, db=db): try: process_file( request, @@ -597,7 +547,7 @@ def update_file_data_content_by_id( file = Files.get_file_by_id(id=id, db=db) except Exception as e: log.exception(e) - log.error(f"Error processing file: {file.id}") + log.error(f'Error processing file: {file.id}') # Propagate content change to all knowledge collections referencing # this file. Without this the old embeddings remain in the knowledge @@ -606,9 +556,7 @@ def update_file_data_content_by_id( for knowledge in knowledges: try: # Remove old embeddings for this file from the KB collection - VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"file_id": id} - ) + VECTOR_DB_CLIENT.delete(collection_name=knowledge.id, filter={'file_id': id}) # Re-add from the now-updated file-{file_id} collection process_file( request, @@ -617,12 +565,9 @@ def update_file_data_content_by_id( db=db, ) except Exception as e: - log.warning( - f"Failed to update knowledge {knowledge.id} after " - f"content change for file {id}: {e}" - ) + log.warning(f'Failed to update knowledge {knowledge.id} after content change for file {id}: {e}') - return {"content": file.data.get("content", "")} + return {'content': file.data.get('content', '')} else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -635,7 +580,7 @@ def update_file_data_content_by_id( ############################ -@router.get("/{id}/content") +@router.get('/{id}/content') async def get_file_content_by_id( id: str, user=Depends(get_verified_user), @@ -650,11 +595,7 @@ async def get_file_content_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): try: file_path = Storage.get_file(file.path) file_path = Path(file_path) @@ -662,30 +603,22 @@ async def get_file_content_by_id( # Check if the file already exists in the cache if file_path.is_file(): # Handle Unicode filenames - filename = file.meta.get("name", file.filename) + filename = file.meta.get('name', file.filename) encoded_filename = quote(filename) # RFC5987 encoding - content_type = file.meta.get("content_type") - filename = file.meta.get("name", file.filename) + content_type = file.meta.get('content_type') + filename = file.meta.get('name', file.filename) encoded_filename = quote(filename) headers = {} if attachment: - headers["Content-Disposition"] = ( - f"attachment; filename*=UTF-8''{encoded_filename}" - ) + headers['Content-Disposition'] = f"attachment; filename*=UTF-8''{encoded_filename}" else: - if content_type == "application/pdf" or filename.lower().endswith( - ".pdf" - ): - headers["Content-Disposition"] = ( - f"inline; filename*=UTF-8''{encoded_filename}" - ) - content_type = "application/pdf" - elif content_type != "text/plain": - headers["Content-Disposition"] = ( - f"attachment; filename*=UTF-8''{encoded_filename}" - ) + if content_type == 'application/pdf' or filename.lower().endswith('.pdf'): + headers['Content-Disposition'] = f"inline; filename*=UTF-8''{encoded_filename}" + content_type = 'application/pdf' + elif content_type != 'text/plain': + headers['Content-Disposition'] = f"attachment; filename*=UTF-8''{encoded_filename}" return FileResponse(file_path, headers=headers, media_type=content_type) @@ -698,10 +631,10 @@ async def get_file_content_by_id( raise e except Exception as e: log.exception(e) - log.error("Error getting file content") + log.error('Error getting file content') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), + detail=ERROR_MESSAGES.DEFAULT('Error getting file content'), ) else: raise HTTPException( @@ -710,10 +643,8 @@ async def get_file_content_by_id( ) -@router.get("/{id}/content/html") -async def get_html_file_content_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}/content/html') +async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): file = Files.get_file_by_id(id, db=db) if not file: @@ -723,24 +654,20 @@ async def get_html_file_content_by_id( ) file_user = Users.get_user_by_id(file.user_id, db=db) - if not file_user.role == "admin": + if not file_user.role == 'admin': raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): try: file_path = Storage.get_file(file.path) file_path = Path(file_path) # Check if the file already exists in the cache if file_path.is_file(): - log.info(f"file_path: {file_path}") + log.info(f'file_path: {file_path}') return FileResponse(file_path) else: raise HTTPException( @@ -751,10 +678,10 @@ async def get_html_file_content_by_id( raise e except Exception as e: log.exception(e) - log.error("Error getting file content") + log.error('Error getting file content') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), + detail=ERROR_MESSAGES.DEFAULT('Error getting file content'), ) else: raise HTTPException( @@ -763,10 +690,8 @@ async def get_html_file_content_by_id( ) -@router.get("/{id}/content/{file_name}") -async def get_file_content_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}/content/{file_name}') +async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): file = Files.get_file_by_id(id, db=db) if not file: @@ -775,19 +700,13 @@ async def get_file_content_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "read", user, db=db) - ): + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'read', user, db=db): file_path = file.path # Handle Unicode filenames - filename = file.meta.get("name", file.filename) + filename = file.meta.get('name', file.filename) encoded_filename = quote(filename) # RFC5987 encoding - headers = { - "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" - } + headers = {'Content-Disposition': f"attachment; filename*=UTF-8''{encoded_filename}"} if file_path: file_path = Storage.get_file(file_path) @@ -803,16 +722,16 @@ async def get_file_content_by_id( ) else: # File path doesn’t exist, return the content as .txt if possible - file_content = file.content.get("content", "") + file_content = file.content.get('content', '') file_name = file.filename # Create a generator that encodes the file content def generator(): - yield file_content.encode("utf-8") + yield file_content.encode('utf-8') return StreamingResponse( generator(), - media_type="text/plain", + media_type='text/plain', headers=headers, ) else: @@ -827,10 +746,8 @@ async def get_file_content_by_id( ############################ -@router.delete("/{id}") -async def delete_file_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/{id}') +async def delete_file_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): file = Files.get_file_by_id(id, db=db) if not file: @@ -839,12 +756,7 @@ async def delete_file_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if ( - file.user_id == user.id - or user.role == "admin" - or has_access_to_file(id, "write", user, db=db) - ): - + if file.user_id == user.id or user.role == 'admin' or has_access_to_file(id, 'write', user, db=db): # Clean up KB associations and embeddings before deleting knowledges = Knowledges.get_knowledges_by_file_id(id, db=db) for knowledge in knowledges: @@ -852,33 +764,29 @@ async def delete_file_by_id( Knowledges.remove_file_from_knowledge_by_id(knowledge.id, id, db=db) # Clean KB embeddings (same logic as /knowledge/{id}/file/remove) try: - VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"file_id": id} - ) + VECTOR_DB_CLIENT.delete(collection_name=knowledge.id, filter={'file_id': id}) if file.hash: - VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"hash": file.hash} - ) + VECTOR_DB_CLIENT.delete(collection_name=knowledge.id, filter={'hash': file.hash}) except Exception as e: - log.debug(f"KB embedding cleanup for {knowledge.id}: {e}") + log.debug(f'KB embedding cleanup for {knowledge.id}: {e}') result = Files.delete_file_by_id(id, db=db) if result: try: Storage.delete_file(file.path) - VECTOR_DB_CLIENT.delete(collection_name=f"file-{id}") + VECTOR_DB_CLIENT.delete(collection_name=f'file-{id}') except Exception as e: log.exception(e) - log.error("Error deleting files") + log.error('Error deleting files') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting files'), ) - return {"message": "File deleted successfully"} + return {'message': 'File deleted successfully'} else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting file"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting file'), ) else: raise HTTPException( diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index f6269e7b72..0bf5a87f1e 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -44,7 +44,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[FolderNameIdResponse]) +@router.get('/', response_model=list[FolderNameIdResponse]) async def get_folders( request: Request, user=Depends(get_verified_user), @@ -56,9 +56,9 @@ async def get_folders( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if user.role != "admin" and not has_permission( + if user.role != 'admin' and not has_permission( user.id, - "features.folders", + 'features.folders', request.app.state.config.USER_PERMISSIONS, db=db, ): @@ -72,35 +72,24 @@ async def get_folders( # Verify folder data integrity folder_list = [] for folder in folders: - if folder.parent_id and not Folders.get_folder_by_id_and_user_id( - folder.parent_id, user.id, db=db - ): - folder = Folders.update_folder_parent_id_by_id_and_user_id( - folder.id, user.id, None, db=db - ) + if folder.parent_id and not Folders.get_folder_by_id_and_user_id(folder.parent_id, user.id, db=db): + folder = Folders.update_folder_parent_id_by_id_and_user_id(folder.id, user.id, None, db=db) if folder.data: - if "files" in folder.data: + if 'files' in folder.data: valid_files = [] - for file in folder.data["files"]: - - if file.get("type") == "file": - if Files.check_access_by_user_id( - file.get("id"), user.id, "read", db=db - ): + for file in folder.data['files']: + if file.get('type') == 'file': + if Files.check_access_by_user_id(file.get('id'), user.id, 'read', db=db): valid_files.append(file) - elif file.get("type") == "collection": - if Knowledges.check_access_by_user_id( - file.get("id"), user.id, "read", db=db - ): + elif file.get('type') == 'collection': + if Knowledges.check_access_by_user_id(file.get('id'), user.id, 'read', db=db): valid_files.append(file) else: valid_files.append(file) - folder.data["files"] = valid_files - Folders.update_folder_by_id_and_user_id( - folder.id, user.id, FolderUpdateForm(data=folder.data), db=db - ) + folder.data['files'] = valid_files + Folders.update_folder_by_id_and_user_id(folder.id, user.id, FolderUpdateForm(data=folder.data), db=db) folder_list.append(FolderNameIdResponse(**folder.model_dump())) @@ -112,33 +101,29 @@ async def get_folders( ############################ -@router.post("/") +@router.post('/') def create_folder( form_data: FolderForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - form_data.parent_id, user.id, form_data.name, db=db - ) + folder = Folders.get_folder_by_parent_id_and_user_id_and_name(form_data.parent_id, user.id, form_data.name, db=db) if folder: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + detail=ERROR_MESSAGES.DEFAULT('Folder already exists'), ) try: - folder = Folders.insert_new_folder( - user.id, form_data, form_data.parent_id, db=db - ) + folder = Folders.insert_new_folder(user.id, form_data, form_data.parent_id, db=db) return folder except Exception as e: log.exception(e) - log.error("Error creating folder") + log.error('Error creating folder') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating folder"), + detail=ERROR_MESSAGES.DEFAULT('Error creating folder'), ) @@ -147,10 +132,8 @@ def create_folder( ############################ -@router.get("/{id}", response_model=Optional[FolderModel]) -async def get_folder_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}', response_model=Optional[FolderModel]) +async def get_folder_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: return folder @@ -166,7 +149,7 @@ async def get_folder_by_id( ############################ -@router.post("/{id}/update") +@router.post('/{id}/update') async def update_folder_name_by_id( id: str, form_data: FolderUpdateForm, @@ -175,7 +158,6 @@ async def update_folder_name_by_id( ): folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: - if form_data.name is not None: # Check if folder with same name exists existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( @@ -184,20 +166,18 @@ async def update_folder_name_by_id( if existing_folder and existing_folder.id != id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + detail=ERROR_MESSAGES.DEFAULT('Folder already exists'), ) try: - folder = Folders.update_folder_by_id_and_user_id( - id, user.id, form_data, db=db - ) + folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data, db=db) return folder except Exception as e: log.exception(e) - log.error(f"Error updating folder: {id}") + log.error(f'Error updating folder: {id}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + detail=ERROR_MESSAGES.DEFAULT('Error updating folder'), ) else: raise HTTPException( @@ -215,7 +195,7 @@ class FolderParentIdForm(BaseModel): parent_id: Optional[str] = None -@router.post("/{id}/update/parent") +@router.post('/{id}/update/parent') async def update_folder_parent_id_by_id( id: str, form_data: FolderParentIdForm, @@ -231,20 +211,18 @@ async def update_folder_parent_id_by_id( if existing_folder: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + detail=ERROR_MESSAGES.DEFAULT('Folder already exists'), ) try: - folder = Folders.update_folder_parent_id_by_id_and_user_id( - id, user.id, form_data.parent_id, db=db - ) + folder = Folders.update_folder_parent_id_by_id_and_user_id(id, user.id, form_data.parent_id, db=db) return folder except Exception as e: log.exception(e) - log.error(f"Error updating folder: {id}") + log.error(f'Error updating folder: {id}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + detail=ERROR_MESSAGES.DEFAULT('Error updating folder'), ) else: raise HTTPException( @@ -262,7 +240,7 @@ class FolderIsExpandedForm(BaseModel): is_expanded: bool -@router.post("/{id}/update/expanded") +@router.post('/{id}/update/expanded') async def update_folder_is_expanded_by_id( id: str, form_data: FolderIsExpandedForm, @@ -272,16 +250,14 @@ async def update_folder_is_expanded_by_id( folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: try: - folder = Folders.update_folder_is_expanded_by_id_and_user_id( - id, user.id, form_data.is_expanded, db=db - ) + folder = Folders.update_folder_is_expanded_by_id_and_user_id(id, user.id, form_data.is_expanded, db=db) return folder except Exception as e: log.exception(e) - log.error(f"Error updating folder: {id}") + log.error(f'Error updating folder: {id}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + detail=ERROR_MESSAGES.DEFAULT('Error updating folder'), ) else: raise HTTPException( @@ -295,7 +271,7 @@ async def update_folder_is_expanded_by_id( ############################ -@router.delete("/{id}") +@router.delete('/{id}') async def delete_folder_by_id( request: Request, id: str, @@ -305,9 +281,9 @@ async def delete_folder_by_id( ): if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db): chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db + user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS, db=db ) - if user.role != "admin" and not chat_delete_permission: + if user.role != 'admin' and not chat_delete_permission: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -319,33 +295,25 @@ async def delete_folder_by_id( folder = folders.pop() if folder: try: - folder_ids = Folders.delete_folder_by_id_and_user_id( - folder.id, user.id, db=db - ) + folder_ids = Folders.delete_folder_by_id_and_user_id(folder.id, user.id, db=db) for folder_id in folder_ids: if delete_contents: - Chats.delete_chats_by_user_id_and_folder_id( - user.id, folder_id, db=db - ) + Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id, db=db) else: - Chats.move_chats_by_user_id_and_folder_id( - user.id, folder_id, None, db=db - ) + Chats.move_chats_by_user_id_and_folder_id(user.id, folder_id, None, db=db) return True except Exception as e: log.exception(e) - log.error(f"Error deleting folder: {id}") + log.error(f'Error deleting folder: {id}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting folder'), ) finally: # Get all subfolders - subfolders = Folders.get_folders_by_parent_id_and_user_id( - folder.id, user.id, db=db - ) + subfolders = Folders.get_folders_by_parent_id_and_user_id(folder.id, user.id, db=db) folders.extend(subfolders) else: diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 3af3b1664a..44f139dc07 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -39,17 +39,13 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[FunctionResponse]) -async def get_functions( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/', response_model=list[FunctionResponse]) +async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)): return Functions.get_functions(db=db) -@router.get("/list", response_model=list[FunctionUserResponse]) -async def get_function_list( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/list', response_model=list[FunctionUserResponse]) +async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)): return Functions.get_function_list(db=db) @@ -58,7 +54,7 @@ async def get_function_list( ############################ -@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) +@router.get('/export', response_model=list[FunctionModel | FunctionWithValvesModel]) async def get_functions( include_valves: bool = False, user=Depends(get_admin_user), @@ -78,70 +74,59 @@ class LoadUrlForm(BaseModel): def github_url_to_raw_url(url: str) -> str: # Handle 'tree' (folder) URLs (add main.py at the end) - m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url) + m1 = re.match(r'https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)', url) if m1: org, repo, branch, path = m1.groups() - return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py" + return f'https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip("/")}/main.py' # Handle 'blob' (file) URLs - m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url) + m2 = re.match(r'https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)', url) if m2: org, repo, branch, path = m2.groups() - return ( - f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}" - ) + return f'https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}' # No match; return as-is return url -@router.post("/load/url", response_model=Optional[dict]) -async def load_function_from_url( - request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user) -): +@router.post('/load/url', response_model=Optional[dict]) +async def load_function_from_url(request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)): # NOTE: This is NOT a SSRF vulnerability: # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, # and does NOT accept untrusted user input. Access is enforced by authentication. url = str(form_data.url) if not url: - raise HTTPException(status_code=400, detail="Please enter a valid URL") + raise HTTPException(status_code=400, detail='Please enter a valid URL') url = github_url_to_raw_url(url) - url_parts = url.rstrip("/").split("/") + url_parts = url.rstrip('/').split('/') file_name = url_parts[-1] function_name = ( file_name[:-3] - if ( - file_name.endswith(".py") - and (not file_name.startswith(("main.py", "index.py", "__init__.py"))) - ) - else url_parts[-2] if len(url_parts) > 1 else "function" + if (file_name.endswith('.py') and (not file_name.startswith(('main.py', 'index.py', '__init__.py')))) + else url_parts[-2] + if len(url_parts) > 1 + else 'function' ) try: async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: - async with session.get( - url, headers={"Content-Type": "application/json"} - ) as resp: + async with session.get(url, headers={'Content-Type': 'application/json'}) as resp: if resp.status != 200: - raise HTTPException( - status_code=resp.status, detail="Failed to fetch the function" - ) + raise HTTPException(status_code=resp.status, detail='Failed to fetch the function') data = await resp.text() if not data: - raise HTTPException( - status_code=400, detail="No data received from the URL" - ) + raise HTTPException(status_code=400, detail='No data received from the URL') return { - "name": function_name, - "content": data, + 'name': function_name, + 'content': data, } except Exception as e: - raise HTTPException(status_code=500, detail=f"Error importing function: {e}") + raise HTTPException(status_code=500, detail=f'Error importing function: {e}') ############################ @@ -153,7 +138,7 @@ class SyncFunctionsForm(BaseModel): functions: list[FunctionWithValvesModel] = [] -@router.post("/sync", response_model=list[FunctionWithValvesModel]) +@router.post('/sync', response_model=list[FunctionWithValvesModel]) async def sync_functions( request: Request, form_data: SyncFunctionsForm, @@ -168,21 +153,17 @@ async def sync_functions( content=function.content, ) - if hasattr(function_module, "Valves") and function.valves: + if hasattr(function_module, 'Valves') and function.valves: Valves = function_module.Valves try: - Valves( - **{k: v for k, v in function.valves.items() if v is not None} - ) + Valves(**{k: v for k, v in function.valves.items() if v is not None}) except Exception as e: - log.exception( - f"Error validating valves for function {function.id}: {e}" - ) + log.exception(f'Error validating valves for function {function.id}: {e}') raise e return Functions.sync_functions(user.id, form_data.functions, db=db) except Exception as e: - log.exception(f"Failed to load a function: {e}") + log.exception(f'Failed to load a function: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -194,7 +175,7 @@ async def sync_functions( ############################ -@router.post("/create", response_model=Optional[FunctionResponse]) +@router.post('/create', response_model=Optional[FunctionResponse]) async def create_new_function( request: Request, form_data: FunctionForm, @@ -204,7 +185,7 @@ async def create_new_function( if not form_data.id.isidentifier(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Only alphanumeric characters and underscores are allowed in the id", + detail='Only alphanumeric characters and underscores are allowed in the id', ) form_data.id = form_data.id.lower() @@ -222,27 +203,23 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function( - user.id, function_type, form_data, db=db - ) + function = Functions.insert_new_function(user.id, function_type, form_data, db=db) - function_cache_dir = CACHE_DIR / "functions" / form_data.id + function_cache_dir = CACHE_DIR / 'functions' / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) - if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id( - form_data.id, {"toggle": True}, db=db - ) + if function_type == 'filter' and getattr(function_module, 'toggle', None): + Functions.update_function_metadata_by_id(form_data.id, {'toggle': True}, db=db) if function: return function else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating function"), + detail=ERROR_MESSAGES.DEFAULT('Error creating function'), ) except Exception as e: - log.exception(f"Failed to create a new function: {e}") + log.exception(f'Failed to create a new function: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -259,10 +236,8 @@ async def create_new_function( ############################ -@router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}', response_model=Optional[FunctionModel]) +async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): function = Functions.get_function_by_id(id, db=db) if function: @@ -279,22 +254,18 @@ async def get_function_by_id( ############################ -@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) -async def toggle_function_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.post('/id/{id}/toggle', response_model=Optional[FunctionModel]) +async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): function = Functions.get_function_by_id(id, db=db) if function: - function = Functions.update_function_by_id( - id, {"is_active": not function.is_active}, db=db - ) + function = Functions.update_function_by_id(id, {'is_active': not function.is_active}, db=db) if function: return function else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + detail=ERROR_MESSAGES.DEFAULT('Error updating function'), ) else: raise HTTPException( @@ -308,22 +279,18 @@ async def toggle_function_by_id( ############################ -@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) -async def toggle_global_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.post('/id/{id}/toggle/global', response_model=Optional[FunctionModel]) +async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): function = Functions.get_function_by_id(id, db=db) if function: - function = Functions.update_function_by_id( - id, {"is_global": not function.is_global}, db=db - ) + function = Functions.update_function_by_id(id, {'is_global': not function.is_global}, db=db) if function: return function else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + detail=ERROR_MESSAGES.DEFAULT('Error updating function'), ) else: raise HTTPException( @@ -337,7 +304,7 @@ async def toggle_global_by_id( ############################ -@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) +@router.post('/id/{id}/update', response_model=Optional[FunctionModel]) async def update_function_by_id( request: Request, id: str, @@ -347,28 +314,26 @@ async def update_function_by_id( ): try: form_data.content = replace_imports(form_data.content) - function_module, function_type, frontmatter = load_function_module_by_id( - id, content=form_data.content - ) + function_module, function_type, frontmatter = load_function_module_by_id(id, content=form_data.content) form_data.meta.manifest = frontmatter FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[id] = function_module - updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} + updated = {**form_data.model_dump(exclude={'id'}), 'type': function_type} log.debug(updated) function = Functions.update_function_by_id(id, updated, db=db) - if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db) + if function_type == 'filter' and getattr(function_module, 'toggle', None): + Functions.update_function_metadata_by_id(id, {'toggle': True}, db=db) if function: return function else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + detail=ERROR_MESSAGES.DEFAULT('Error updating function'), ) except Exception as e: @@ -383,7 +348,7 @@ async def update_function_by_id( ############################ -@router.delete("/id/{id}/delete", response_model=bool) +@router.delete('/id/{id}/delete', response_model=bool) async def delete_function_by_id( request: Request, id: str, @@ -405,10 +370,8 @@ async def delete_function_by_id( ############################ -@router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_function_valves_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/valves', response_model=Optional[dict]) +async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): function = Functions.get_function_by_id(id, db=db) if function: try: @@ -431,7 +394,7 @@ async def get_function_valves_by_id( ############################ -@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) +@router.get('/id/{id}/valves/spec', response_model=Optional[dict]) async def get_function_valves_spec_by_id( request: Request, id: str, @@ -440,11 +403,9 @@ async def get_function_valves_spec_by_id( ): function = Functions.get_function_by_id(id, db=db) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id - ) + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) - if hasattr(function_module, "Valves"): + if hasattr(function_module, 'Valves'): Valves = function_module.Valves schema = Valves.schema() # Resolve dynamic options for select dropdowns @@ -463,7 +424,7 @@ async def get_function_valves_spec_by_id( ############################ -@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +@router.post('/id/{id}/valves/update', response_model=Optional[dict]) async def update_function_valves_by_id( request: Request, id: str, @@ -473,11 +434,9 @@ async def update_function_valves_by_id( ): function = Functions.get_function_by_id(id, db=db) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id - ) + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) - if hasattr(function_module, "Valves"): + if hasattr(function_module, 'Valves'): Valves = function_module.Valves try: @@ -488,7 +447,7 @@ async def update_function_valves_by_id( Functions.update_function_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: - log.exception(f"Error updating function values by id {id}: {e}") + log.exception(f'Error updating function values by id {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -511,16 +470,12 @@ async def update_function_valves_by_id( ############################ -@router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_function_user_valves_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/valves/user', response_model=Optional[dict]) +async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): function = Functions.get_function_by_id(id, db=db) if function: try: - user_valves = Functions.get_user_valves_by_id_and_user_id( - id, user.id, db=db - ) + user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db) return user_valves except Exception as e: raise HTTPException( @@ -534,7 +489,7 @@ async def get_function_user_valves_by_id( ) -@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) +@router.get('/id/{id}/valves/user/spec', response_model=Optional[dict]) async def get_function_user_valves_spec_by_id( request: Request, id: str, @@ -543,11 +498,9 @@ async def get_function_user_valves_spec_by_id( ): function = Functions.get_function_by_id(id, db=db) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id - ) + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) - if hasattr(function_module, "UserValves"): + if hasattr(function_module, 'UserValves'): UserValves = function_module.UserValves schema = UserValves.schema() # Resolve dynamic options for select dropdowns @@ -561,7 +514,7 @@ async def get_function_user_valves_spec_by_id( ) -@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) +@router.post('/id/{id}/valves/user/update', response_model=Optional[dict]) async def update_function_user_valves_by_id( request: Request, id: str, @@ -572,23 +525,19 @@ async def update_function_user_valves_by_id( function = Functions.get_function_by_id(id, db=db) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id - ) + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) - if hasattr(function_module, "UserValves"): + if hasattr(function_module, 'UserValves'): UserValves = function_module.UserValves try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) - Functions.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict, db=db - ) + Functions.update_user_valves_by_id_and_user_id(id, user.id, user_valves_dict, db=db) return user_valves_dict except Exception as e: - log.exception(f"Error updating function user valves by id {id}: {e}") + log.exception(f'Error updating function user valves by id {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index 3711a52ab4..4e9688c3d8 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -31,20 +31,19 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[GroupResponse]) +@router.get('/', response_model=list[GroupResponse]) async def get_groups( share: Optional[bool] = None, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - filter = {} # Admins can share to all groups regardless of share setting - if user.role != "admin": - filter["member_id"] = user.id + if user.role != 'admin': + filter['member_id'] = user.id if share is not None: - filter["share"] = share + filter['share'] = share groups = Groups.get_groups(filter=filter, db=db) @@ -56,7 +55,7 @@ async def get_groups( ############################ -@router.post("/create", response_model=Optional[GroupResponse]) +@router.post('/create', response_model=Optional[GroupResponse]) async def create_new_group( form_data: GroupForm, user=Depends(get_admin_user), @@ -72,10 +71,10 @@ async def create_new_group( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating group"), + detail=ERROR_MESSAGES.DEFAULT('Error creating group'), ) except Exception as e: - log.exception(f"Error creating a new group: {e}") + log.exception(f'Error creating a new group: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -87,10 +86,8 @@ async def create_new_group( ############################ -@router.get("/id/{id}", response_model=Optional[GroupResponse]) -async def get_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}', response_model=Optional[GroupResponse]) +async def get_group_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): group = Groups.get_group_by_id(id, db=db) if group: return GroupResponse( @@ -104,10 +101,8 @@ async def get_group_by_id( ) -@router.get("/id/{id}/info", response_model=Optional[GroupInfoResponse]) -async def get_group_info_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/info', response_model=Optional[GroupInfoResponse]) +async def get_group_info_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): group = Groups.get_group_by_id(id, db=db) if group: return GroupInfoResponse( @@ -131,10 +126,8 @@ class GroupExportResponse(GroupResponse): pass -@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse]) -async def export_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/export', response_model=Optional[GroupExportResponse]) +async def export_group_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): group = Groups.get_group_by_id(id, db=db) if group: return GroupExportResponse( @@ -154,15 +147,13 @@ async def export_group_by_id( ############################ -@router.post("/id/{id}/users", response_model=list[UserInfoResponse]) -async def get_users_in_group( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.post('/id/{id}/users', response_model=list[UserInfoResponse]) +async def get_users_in_group(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): try: users = Users.get_users_by_group_id(id, db=db) return users except Exception as e: - log.exception(f"Error adding users to group {id}: {e}") + log.exception(f'Error adding users to group {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -174,7 +165,7 @@ async def get_users_in_group( ############################ -@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) +@router.post('/id/{id}/update', response_model=Optional[GroupResponse]) async def update_group_by_id( id: str, form_data: GroupUpdateForm, @@ -191,10 +182,10 @@ async def update_group_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating group"), + detail=ERROR_MESSAGES.DEFAULT('Error updating group'), ) except Exception as e: - log.exception(f"Error updating group {id}: {e}") + log.exception(f'Error updating group {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -206,7 +197,7 @@ async def update_group_by_id( ############################ -@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) +@router.post('/id/{id}/users/add', response_model=Optional[GroupResponse]) async def add_user_to_group( id: str, form_data: UserIdsForm, @@ -226,17 +217,17 @@ async def add_user_to_group( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error adding users to group"), + detail=ERROR_MESSAGES.DEFAULT('Error adding users to group'), ) except Exception as e: - log.exception(f"Error adding users to group {id}: {e}") + log.exception(f'Error adding users to group {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) -@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) +@router.post('/id/{id}/users/remove', response_model=Optional[GroupResponse]) async def remove_users_from_group( id: str, form_data: UserIdsForm, @@ -253,10 +244,10 @@ async def remove_users_from_group( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error removing users from group"), + detail=ERROR_MESSAGES.DEFAULT('Error removing users from group'), ) except Exception as e: - log.exception(f"Error removing users from group {id}: {e}") + log.exception(f'Error removing users from group {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -268,10 +259,8 @@ async def remove_users_from_group( ############################ -@router.delete("/id/{id}/delete", response_model=bool) -async def delete_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.delete('/id/{id}/delete', response_model=bool) +async def delete_group_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): try: result = Groups.delete_group_by_id(id, db=db) if result: @@ -279,10 +268,10 @@ async def delete_group_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), + detail=ERROR_MESSAGES.DEFAULT('Error deleting group'), ) except Exception as e: - log.exception(f"Error deleting group {id}: {e}") + log.exception(f'Error deleting group {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 48209fc05c..060461f2b7 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -42,67 +42,65 @@ from pydantic import BaseModel log = logging.getLogger(__name__) -IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations" +IMAGE_CACHE_DIR = CACHE_DIR / 'image' / 'generations' IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) router = APIRouter() def set_image_model(request: Request, model: str): - log.info(f"Setting image model to {model}") + log.info(f'Setting image model to {model}') request.app.state.config.IMAGE_GENERATION_MODEL = model - if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: + if request.app.state.config.IMAGE_GENERATION_ENGINE in ['', 'automatic1111']: api_auth = get_automatic1111_api_auth(request) try: r = requests.get( - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": api_auth}, + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', + headers={'authorization': api_auth}, ) options = r.json() - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model + if model != options['sd_model_checkpoint']: + options['sd_model_checkpoint'] = model r = requests.post( - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', json=options, - headers={"authorization": api_auth}, + headers={'authorization': api_auth}, ) except Exception as e: - log.debug(f"{e}") + log.debug(f'{e}') return request.app.state.config.IMAGE_GENERATION_MODEL def get_image_model(request): - if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai': return ( request.app.state.config.IMAGE_GENERATION_MODEL if request.app.state.config.IMAGE_GENERATION_MODEL - else "dall-e-2" + else 'dall-e-2' ) - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'gemini': return ( request.app.state.config.IMAGE_GENERATION_MODEL if request.app.state.config.IMAGE_GENERATION_MODEL - else "imagen-3.0-generate-002" + else 'imagen-3.0-generate-002' ) - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'comfyui': return ( - request.app.state.config.IMAGE_GENERATION_MODEL - if request.app.state.config.IMAGE_GENERATION_MODEL - else "" + request.app.state.config.IMAGE_GENERATION_MODEL if request.app.state.config.IMAGE_GENERATION_MODEL else '' ) elif ( - request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" - or request.app.state.config.IMAGE_GENERATION_ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == 'automatic1111' + or request.app.state.config.IMAGE_GENERATION_ENGINE == '' ): try: r = requests.get( - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth(request)}, + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', + headers={'authorization': get_automatic1111_api_auth(request)}, ) options = r.json() - return options["sd_model_checkpoint"] + return options['sd_model_checkpoint'] except Exception as e: request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -151,79 +149,71 @@ class ImagesConfig(BaseModel): IMAGES_EDIT_COMFYUI_WORKFLOW_NODES: list[dict] -@router.get("/config", response_model=ImagesConfig) +@router.get('/config', response_model=ImagesConfig) async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION, - "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, - "IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE, - "IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, - "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, - "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, - "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, - "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, - "IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS, - "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS, - "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, - "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY, - "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, - "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, - "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, - "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, - "ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT, - "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, - "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, - "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, - "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, - "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, - "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, - "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, - "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, - "IMAGES_EDIT_COMFYUI_BASE_URL": request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, - "IMAGES_EDIT_COMFYUI_API_KEY": request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, - "IMAGES_EDIT_COMFYUI_WORKFLOW": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, - "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, + 'ENABLE_IMAGE_GENERATION': request.app.state.config.ENABLE_IMAGE_GENERATION, + 'ENABLE_IMAGE_PROMPT_GENERATION': request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, + 'IMAGE_GENERATION_ENGINE': request.app.state.config.IMAGE_GENERATION_ENGINE, + 'IMAGE_GENERATION_MODEL': request.app.state.config.IMAGE_GENERATION_MODEL, + 'IMAGE_SIZE': request.app.state.config.IMAGE_SIZE, + 'IMAGE_STEPS': request.app.state.config.IMAGE_STEPS, + 'IMAGES_OPENAI_API_BASE_URL': request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + 'IMAGES_OPENAI_API_KEY': request.app.state.config.IMAGES_OPENAI_API_KEY, + 'IMAGES_OPENAI_API_VERSION': request.app.state.config.IMAGES_OPENAI_API_VERSION, + 'IMAGES_OPENAI_API_PARAMS': request.app.state.config.IMAGES_OPENAI_API_PARAMS, + 'AUTOMATIC1111_BASE_URL': request.app.state.config.AUTOMATIC1111_BASE_URL, + 'AUTOMATIC1111_API_AUTH': request.app.state.config.AUTOMATIC1111_API_AUTH, + 'AUTOMATIC1111_PARAMS': request.app.state.config.AUTOMATIC1111_PARAMS, + 'COMFYUI_BASE_URL': request.app.state.config.COMFYUI_BASE_URL, + 'COMFYUI_API_KEY': request.app.state.config.COMFYUI_API_KEY, + 'COMFYUI_WORKFLOW': request.app.state.config.COMFYUI_WORKFLOW, + 'COMFYUI_WORKFLOW_NODES': request.app.state.config.COMFYUI_WORKFLOW_NODES, + 'IMAGES_GEMINI_API_BASE_URL': request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + 'IMAGES_GEMINI_API_KEY': request.app.state.config.IMAGES_GEMINI_API_KEY, + 'IMAGES_GEMINI_ENDPOINT_METHOD': request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + 'ENABLE_IMAGE_EDIT': request.app.state.config.ENABLE_IMAGE_EDIT, + 'IMAGE_EDIT_ENGINE': request.app.state.config.IMAGE_EDIT_ENGINE, + 'IMAGE_EDIT_MODEL': request.app.state.config.IMAGE_EDIT_MODEL, + 'IMAGE_EDIT_SIZE': request.app.state.config.IMAGE_EDIT_SIZE, + 'IMAGES_EDIT_OPENAI_API_BASE_URL': request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, + 'IMAGES_EDIT_OPENAI_API_KEY': request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, + 'IMAGES_EDIT_OPENAI_API_VERSION': request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, + 'IMAGES_EDIT_GEMINI_API_BASE_URL': request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, + 'IMAGES_EDIT_GEMINI_API_KEY': request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, + 'IMAGES_EDIT_COMFYUI_BASE_URL': request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, + 'IMAGES_EDIT_COMFYUI_API_KEY': request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, + 'IMAGES_EDIT_COMFYUI_WORKFLOW': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, + 'IMAGES_EDIT_COMFYUI_WORKFLOW_NODES': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, } -@router.post("/config/update") -async def update_config( - request: Request, form_data: ImagesConfig, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_config(request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)): request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION # Create Image - request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ( - form_data.ENABLE_IMAGE_PROMPT_GENERATION - ) + request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE set_image_model(request, form_data.IMAGE_GENERATION_MODEL) - if form_data.IMAGE_SIZE == "auto" and not re.match( + if form_data.IMAGE_SIZE == 'auto' and not re.match( IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, form_data.IMAGE_GENERATION_MODEL ): raise HTTPException( status_code=400, detail=ERROR_MESSAGES.INCORRECT_FORMAT( - f" (auto is only allowed with models matching {IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN})." + f' (auto is only allowed with models matching {IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN}).' ), ) - pattern = r"^\d+x\d+$" - if ( - form_data.IMAGE_SIZE == "auto" - or form_data.IMAGE_SIZE == "" - or re.match(pattern, form_data.IMAGE_SIZE) - ): + pattern = r'^\d+x\d+$' + if form_data.IMAGE_SIZE == 'auto' or form_data.IMAGE_SIZE == '' or re.match(pattern, form_data.IMAGE_SIZE): request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( status_code=400, - detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."), + detail=ERROR_MESSAGES.INCORRECT_FORMAT(' (e.g., 512x512).'), ) if form_data.IMAGE_STEPS >= 0: @@ -231,36 +221,26 @@ async def update_config( else: raise HTTPException( status_code=400, - detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."), + detail=ERROR_MESSAGES.INCORRECT_FORMAT(' (e.g., 50).'), ) - request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( - form_data.IMAGES_OPENAI_API_BASE_URL - ) + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = form_data.IMAGES_OPENAI_API_BASE_URL request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.IMAGES_OPENAI_API_KEY - request.app.state.config.IMAGES_OPENAI_API_VERSION = ( - form_data.IMAGES_OPENAI_API_VERSION - ) - request.app.state.config.IMAGES_OPENAI_API_PARAMS = ( - form_data.IMAGES_OPENAI_API_PARAMS - ) + request.app.state.config.IMAGES_OPENAI_API_VERSION = form_data.IMAGES_OPENAI_API_VERSION + request.app.state.config.IMAGES_OPENAI_API_PARAMS = form_data.IMAGES_OPENAI_API_PARAMS request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH request.app.state.config.AUTOMATIC1111_PARAMS = form_data.AUTOMATIC1111_PARAMS - request.app.state.config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip("/") + request.app.state.config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip('/') request.app.state.config.COMFYUI_API_KEY = form_data.COMFYUI_API_KEY request.app.state.config.COMFYUI_WORKFLOW = form_data.COMFYUI_WORKFLOW request.app.state.config.COMFYUI_WORKFLOW_NODES = form_data.COMFYUI_WORKFLOW_NODES - request.app.state.config.IMAGES_GEMINI_API_BASE_URL = ( - form_data.IMAGES_GEMINI_API_BASE_URL - ) + request.app.state.config.IMAGES_GEMINI_API_BASE_URL = form_data.IMAGES_GEMINI_API_BASE_URL request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY - request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = ( - form_data.IMAGES_GEMINI_ENDPOINT_METHOD - ) + request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = form_data.IMAGES_GEMINI_ENDPOINT_METHOD # Edit Image request.app.state.config.ENABLE_IMAGE_EDIT = form_data.ENABLE_IMAGE_EDIT @@ -268,107 +248,85 @@ async def update_config( request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE - request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = ( - form_data.IMAGES_EDIT_OPENAI_API_BASE_URL - ) - request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = ( - form_data.IMAGES_EDIT_OPENAI_API_KEY - ) - request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = ( - form_data.IMAGES_EDIT_OPENAI_API_VERSION - ) + request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = form_data.IMAGES_EDIT_OPENAI_API_BASE_URL + request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = form_data.IMAGES_EDIT_OPENAI_API_KEY + request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = form_data.IMAGES_EDIT_OPENAI_API_VERSION - request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = ( - form_data.IMAGES_EDIT_GEMINI_API_BASE_URL - ) - request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = ( - form_data.IMAGES_EDIT_GEMINI_API_KEY - ) + request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = form_data.IMAGES_EDIT_GEMINI_API_BASE_URL + request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = form_data.IMAGES_EDIT_GEMINI_API_KEY - request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL = ( - form_data.IMAGES_EDIT_COMFYUI_BASE_URL.strip("/") - ) - request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY = ( - form_data.IMAGES_EDIT_COMFYUI_API_KEY - ) - request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW = ( - form_data.IMAGES_EDIT_COMFYUI_WORKFLOW - ) - request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = ( - form_data.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES - ) + request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL = form_data.IMAGES_EDIT_COMFYUI_BASE_URL.strip('/') + request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY = form_data.IMAGES_EDIT_COMFYUI_API_KEY + request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW = form_data.IMAGES_EDIT_COMFYUI_WORKFLOW + request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = form_data.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES return { - "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION, - "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, - "IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE, - "IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, - "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, - "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, - "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, - "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, - "IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS, - "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS, - "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, - "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY, - "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, - "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, - "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, - "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, - "ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT, - "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, - "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, - "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, - "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, - "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, - "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, - "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, - "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, - "IMAGES_EDIT_COMFYUI_BASE_URL": request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, - "IMAGES_EDIT_COMFYUI_API_KEY": request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, - "IMAGES_EDIT_COMFYUI_WORKFLOW": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, - "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, + 'ENABLE_IMAGE_GENERATION': request.app.state.config.ENABLE_IMAGE_GENERATION, + 'ENABLE_IMAGE_PROMPT_GENERATION': request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, + 'IMAGE_GENERATION_ENGINE': request.app.state.config.IMAGE_GENERATION_ENGINE, + 'IMAGE_GENERATION_MODEL': request.app.state.config.IMAGE_GENERATION_MODEL, + 'IMAGE_SIZE': request.app.state.config.IMAGE_SIZE, + 'IMAGE_STEPS': request.app.state.config.IMAGE_STEPS, + 'IMAGES_OPENAI_API_BASE_URL': request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + 'IMAGES_OPENAI_API_KEY': request.app.state.config.IMAGES_OPENAI_API_KEY, + 'IMAGES_OPENAI_API_VERSION': request.app.state.config.IMAGES_OPENAI_API_VERSION, + 'IMAGES_OPENAI_API_PARAMS': request.app.state.config.IMAGES_OPENAI_API_PARAMS, + 'AUTOMATIC1111_BASE_URL': request.app.state.config.AUTOMATIC1111_BASE_URL, + 'AUTOMATIC1111_API_AUTH': request.app.state.config.AUTOMATIC1111_API_AUTH, + 'AUTOMATIC1111_PARAMS': request.app.state.config.AUTOMATIC1111_PARAMS, + 'COMFYUI_BASE_URL': request.app.state.config.COMFYUI_BASE_URL, + 'COMFYUI_API_KEY': request.app.state.config.COMFYUI_API_KEY, + 'COMFYUI_WORKFLOW': request.app.state.config.COMFYUI_WORKFLOW, + 'COMFYUI_WORKFLOW_NODES': request.app.state.config.COMFYUI_WORKFLOW_NODES, + 'IMAGES_GEMINI_API_BASE_URL': request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + 'IMAGES_GEMINI_API_KEY': request.app.state.config.IMAGES_GEMINI_API_KEY, + 'IMAGES_GEMINI_ENDPOINT_METHOD': request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + 'ENABLE_IMAGE_EDIT': request.app.state.config.ENABLE_IMAGE_EDIT, + 'IMAGE_EDIT_ENGINE': request.app.state.config.IMAGE_EDIT_ENGINE, + 'IMAGE_EDIT_MODEL': request.app.state.config.IMAGE_EDIT_MODEL, + 'IMAGE_EDIT_SIZE': request.app.state.config.IMAGE_EDIT_SIZE, + 'IMAGES_EDIT_OPENAI_API_BASE_URL': request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, + 'IMAGES_EDIT_OPENAI_API_KEY': request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, + 'IMAGES_EDIT_OPENAI_API_VERSION': request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, + 'IMAGES_EDIT_GEMINI_API_BASE_URL': request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, + 'IMAGES_EDIT_GEMINI_API_KEY': request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, + 'IMAGES_EDIT_COMFYUI_BASE_URL': request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, + 'IMAGES_EDIT_COMFYUI_API_KEY': request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, + 'IMAGES_EDIT_COMFYUI_WORKFLOW': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, + 'IMAGES_EDIT_COMFYUI_WORKFLOW_NODES': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, } def get_automatic1111_api_auth(request: Request): if request.app.state.config.AUTOMATIC1111_API_AUTH is None: - return "" + return '' else: - auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode( - "utf-8" - ) + auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode('utf-8') auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) - auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") - return f"Basic {auth1111_base64_encoded_string}" + auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode('utf-8') + return f'Basic {auth1111_base64_encoded_string}' -@router.get("/config/url/verify") +@router.get('/config/url/verify') async def verify_url(request: Request, user=Depends(get_admin_user)): - if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111": + if request.app.state.config.IMAGE_GENERATION_ENGINE == 'automatic1111': try: r = requests.get( - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth(request)}, + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', + headers={'authorization': get_automatic1111_api_auth(request)}, ) r.raise_for_status() return True except Exception: request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'comfyui': headers = None if request.app.state.config.COMFYUI_API_KEY: - headers = { - "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" - } + headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'} try: r = requests.get( - url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info", + url=f'{request.app.state.config.COMFYUI_BASE_URL}/object_info', headers=headers, ) r.raise_for_status() @@ -380,27 +338,25 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): return True -@router.get("/models") +@router.get('/models') def get_models(request: Request, user=Depends(get_verified_user)): try: - if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai': return [ - {"id": "dall-e-2", "name": "DALL·E 2"}, - {"id": "dall-e-3", "name": "DALL·E 3"}, - {"id": "gpt-image-1", "name": "GPT-IMAGE 1"}, - {"id": "gpt-image-1.5", "name": "GPT-IMAGE 1.5"}, + {'id': 'dall-e-2', 'name': 'DALL·E 2'}, + {'id': 'dall-e-3', 'name': 'DALL·E 3'}, + {'id': 'gpt-image-1', 'name': 'GPT-IMAGE 1'}, + {'id': 'gpt-image-1.5', 'name': 'GPT-IMAGE 1.5'}, ] - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'gemini': return [ - {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"}, + {'id': 'imagen-3.0-generate-002', 'name': 'imagen-3.0 generate-002'}, ] - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'comfyui': # TODO - get models from comfyui - headers = { - "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" - } + headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'} r = requests.get( - url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info", + url=f'{request.app.state.config.COMFYUI_BASE_URL}/object_info', headers=headers, ) info = r.json() @@ -409,52 +365,46 @@ def get_models(request: Request, user=Depends(get_verified_user)): model_node_id = None for node in request.app.state.config.COMFYUI_WORKFLOW_NODES: - if node["type"] == "model": - if node["node_ids"]: - model_node_id = node["node_ids"][0] + if node['type'] == 'model': + if node['node_ids']: + model_node_id = node['node_ids'][0] break if model_node_id: model_list_key = None - log.info(workflow[model_node_id]["class_type"]) - for key in info[workflow[model_node_id]["class_type"]]["input"][ - "required" - ]: - if "_name" in key: + log.info(workflow[model_node_id]['class_type']) + for key in info[workflow[model_node_id]['class_type']]['input']['required']: + if '_name' in key: model_list_key = key break if model_list_key: return list( map( - lambda model: {"id": model, "name": model}, - info[workflow[model_node_id]["class_type"]]["input"][ - "required" - ][model_list_key][0], + lambda model: {'id': model, 'name': model}, + info[workflow[model_node_id]['class_type']]['input']['required'][model_list_key][0], ) ) else: return list( map( - lambda model: {"id": model, "name": model}, - info["CheckpointLoaderSimple"]["input"]["required"][ - "ckpt_name" - ][0], + lambda model: {'id': model, 'name': model}, + info['CheckpointLoaderSimple']['input']['required']['ckpt_name'][0], ) ) elif ( - request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" - or request.app.state.config.IMAGE_GENERATION_ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == 'automatic1111' + or request.app.state.config.IMAGE_GENERATION_ENGINE == '' ): r = requests.get( - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", - headers={"authorization": get_automatic1111_api_auth(request)}, + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models', + headers={'authorization': get_automatic1111_api_auth(request)}, ) models = r.json() return list( map( - lambda model: {"id": model["title"], "name": model["model_name"]}, + lambda model: {'id': model['title'], 'name': model['model_name']}, models, ) ) @@ -477,30 +427,30 @@ GenerateImageForm = CreateImageForm # Alias for backward compatibility def get_image_data(data: str, headers=None): try: - if data.startswith("http://") or data.startswith("https://"): + if data.startswith('http://') or data.startswith('https://'): if headers: r = requests.get(data, headers=headers) else: r = requests.get(data) r.raise_for_status() - if r.headers["content-type"].split("/")[0] == "image": - mime_type = r.headers["content-type"] + if r.headers['content-type'].split('/')[0] == 'image': + mime_type = r.headers['content-type'] return r.content, mime_type else: - log.error("Url does not point to an image.") + log.error('Url does not point to an image.') return None else: - if "," in data: - header, encoded = data.split(",", 1) - mime_type = header.split(";")[0].lstrip("data:") + if ',' in data: + header, encoded = data.split(',', 1) + mime_type = header.split(';')[0].lstrip('data:') img_data = base64.b64decode(encoded) else: - mime_type = "image/png" + mime_type = 'image/png' img_data = base64.b64decode(data) return img_data, mime_type except Exception as e: - log.exception(f"Error loading image data: {e}") + log.exception(f'Error loading image data: {e}') return None, None @@ -508,9 +458,9 @@ def upload_image(request, image_data, content_type, metadata, user, db=None): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), - filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file + filename=f'generated-image{image_format}', # will be converted to a unique ID on upload_file headers={ - "content-type": content_type, + 'content-type': content_type, }, ) file_item = upload_file_handler( @@ -523,8 +473,8 @@ def upload_image(request, image_data, content_type, metadata, user, db=None): if file_item and file_item.id: # If chat_id and message_id are provided in metadata, link the file to the chat message - chat_id = metadata.get("chat_id") - message_id = metadata.get("message_id") + chat_id = metadata.get('chat_id') + message_id = metadata.get('message_id') if chat_id and message_id: Chats.insert_chat_files( @@ -535,22 +485,20 @@ def upload_image(request, image_data, content_type, metadata, user, db=None): db=db, ) - url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + url = request.app.url_path_for('get_file_content_by_id', id=file_item.id) return file_item, url -@router.post("/generations") -async def generate_images( - request: Request, form_data: CreateImageForm, user=Depends(get_verified_user) -): +@router.post('/generations') +async def generate_images(request: Request, form_data: CreateImageForm, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_IMAGE_GENERATION: raise HTTPException( status_code=403, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if user.role != "admin" and not has_permission( - user.id, "features.image_generation", request.app.state.config.USER_PERMISSIONS + if user.role != 'admin' and not has_permission( + user.id, 'features.image_generation', request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=403, @@ -570,17 +518,14 @@ async def image_generations( # This is only relevant when the user has set IMAGE_SIZE to 'auto' with an # image model other than gpt-image-1, which is warned about on settings save - size = "512x512" - if ( - request.app.state.config.IMAGE_SIZE - and "x" in request.app.state.config.IMAGE_SIZE - ): + size = '512x512' + if request.app.state.config.IMAGE_SIZE and 'x' in request.app.state.config.IMAGE_SIZE: size = request.app.state.config.IMAGE_SIZE - if form_data.size and "x" in form_data.size: + if form_data.size and 'x' in form_data.size: size = form_data.size - width, height = tuple(map(int, size.split("x"))) + width, height = tuple(map(int, size.split('x'))) metadata = metadata or {} @@ -588,36 +533,31 @@ async def image_generations( r = None try: - if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": - + if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai': headers = { - "Authorization": f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}", - "Content-Type": "application/json", + 'Authorization': f'Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}', + 'Content-Type': 'application/json', } if ENABLE_FORWARD_USER_INFO_HEADERS: headers = include_user_info_headers(headers, user) - url = f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations" + url = f'{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations' if request.app.state.config.IMAGES_OPENAI_API_VERSION: - url = f"{url}?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}" + url = f'{url}?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}' data = { - "model": model, - "prompt": form_data.prompt, - "n": form_data.n, - "size": ( - form_data.size - if form_data.size - else request.app.state.config.IMAGE_SIZE - ), + 'model': model, + 'prompt': form_data.prompt, + 'n': form_data.n, + 'size': (form_data.size if form_data.size else request.app.state.config.IMAGE_SIZE), **( {} if re.match( IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, request.app.state.config.IMAGE_GENERATION_MODEL, ) - else {"response_format": "b64_json"} + else {'response_format': 'b64_json'} ), **( {} @@ -639,53 +579,48 @@ async def image_generations( images = [] - for image in res["data"]: - if image_url := image.get("url", None): + for image in res['data']: + if image_url := image.get('url', None): image_data, content_type = get_image_data( image_url, - {k: v for k, v in headers.items() if k != "Content-Type"}, + {k: v for k, v in headers.items() if k != 'Content-Type'}, ) else: - image_data, content_type = get_image_data(image["b64_json"]) + image_data, content_type = get_image_data(image['b64_json']) - _, url = upload_image( - request, image_data, content_type, {**data, **metadata}, user - ) - images.append({"url": url}) + _, url = upload_image(request, image_data, content_type, {**data, **metadata}, user) + images.append({'url': url}) return images - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'gemini': headers = { - "Content-Type": "application/json", - "x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY, + 'Content-Type': 'application/json', + 'x-goog-api-key': request.app.state.config.IMAGES_GEMINI_API_KEY, } data = {} if ( - request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "" - or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict" + request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == '' + or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == 'predict' ): - model = f"{model}:predict" + model = f'{model}:predict' data = { - "instances": {"prompt": form_data.prompt}, - "parameters": { - "sampleCount": form_data.n, - "outputOptions": {"mimeType": "image/png"}, + 'instances': {'prompt': form_data.prompt}, + 'parameters': { + 'sampleCount': form_data.n, + 'outputOptions': {'mimeType': 'image/png'}, }, } - elif ( - request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD - == "generateContent" - ): - model = f"{model}:generateContent" - data = {"contents": [{"parts": [{"text": form_data.prompt}]}]} + elif request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == 'generateContent': + model = f'{model}:generateContent' + data = {'contents': [{'parts': [{'text': form_data.prompt}]}]} # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}", + url=f'{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}', json=data, headers=headers, ) @@ -695,22 +630,16 @@ async def image_generations( images = [] - if model.endswith(":predict"): - for image in res["predictions"]: - image_data, content_type = get_image_data( - image["bytesBase64Encoded"] - ) - _, url = upload_image( - request, image_data, content_type, {**data, **metadata}, user - ) - images.append({"url": url}) - elif model.endswith(":generateContent"): - for image in res["candidates"]: - for part in image["content"]["parts"]: - if part.get("inlineData", {}).get("data"): - image_data, content_type = get_image_data( - part["inlineData"]["data"] - ) + if model.endswith(':predict'): + for image in res['predictions']: + image_data, content_type = get_image_data(image['bytesBase64Encoded']) + _, url = upload_image(request, image_data, content_type, {**data, **metadata}, user) + images.append({'url': url}) + elif model.endswith(':generateContent'): + for image in res['candidates']: + for part in image['content']['parts']: + if part.get('inlineData', {}).get('data'): + image_data, content_type = get_image_data(part['inlineData']['data']) _, url = upload_image( request, image_data, @@ -718,37 +647,30 @@ async def image_generations( {**data, **metadata}, user, ) - images.append({"url": url}) + images.append({'url': url}) return images - elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == 'comfyui': data = { - "prompt": form_data.prompt, - "width": width, - "height": height, - "n": form_data.n, + 'prompt': form_data.prompt, + 'width': width, + 'height': height, + 'n': form_data.n, } - if ( - request.app.state.config.IMAGE_STEPS is not None - or form_data.steps is not None - ): - data["steps"] = ( - form_data.steps - if form_data.steps is not None - else request.app.state.config.IMAGE_STEPS - ) + if request.app.state.config.IMAGE_STEPS is not None or form_data.steps is not None: + data['steps'] = form_data.steps if form_data.steps is not None else request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: - data["negative_prompt"] = form_data.negative_prompt + data['negative_prompt'] = form_data.negative_prompt form_data = ComfyUICreateImageForm( **{ - "workflow": ComfyUIWorkflow( + 'workflow': ComfyUIWorkflow( **{ - "workflow": request.app.state.config.COMFYUI_WORKFLOW, - "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, + 'workflow': request.app.state.config.COMFYUI_WORKFLOW, + 'nodes': request.app.state.config.COMFYUI_WORKFLOW_NODES, } ), **data, @@ -761,18 +683,16 @@ async def image_generations( request.app.state.config.COMFYUI_BASE_URL, request.app.state.config.COMFYUI_API_KEY, ) - log.debug(f"res: {res}") + log.debug(f'res: {res}') images = [] - for image in res["data"]: + for image in res['data']: headers = None if request.app.state.config.COMFYUI_API_KEY: - headers = { - "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" - } + headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'} - image_data, content_type = get_image_data(image["url"], headers) + image_data, content_type = get_image_data(image['url'], headers) _, url = upload_image( request, image_data, @@ -780,34 +700,27 @@ async def image_generations( {**form_data.model_dump(exclude_none=True), **metadata}, user, ) - images.append({"url": url}) + images.append({'url': url}) return images elif ( - request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" - or request.app.state.config.IMAGE_GENERATION_ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == 'automatic1111' + or request.app.state.config.IMAGE_GENERATION_ENGINE == '' ): if form_data.model: set_image_model(request, form_data.model) data = { - "prompt": form_data.prompt, - "batch_size": form_data.n, - "width": width, - "height": height, + 'prompt': form_data.prompt, + 'batch_size': form_data.n, + 'width': width, + 'height': height, } - if ( - request.app.state.config.IMAGE_STEPS is not None - or form_data.steps is not None - ): - data["steps"] = ( - form_data.steps - if form_data.steps is not None - else request.app.state.config.IMAGE_STEPS - ) + if request.app.state.config.IMAGE_STEPS is not None or form_data.steps is not None: + data['steps'] = form_data.steps if form_data.steps is not None else request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: - data["negative_prompt"] = form_data.negative_prompt + data['negative_prompt'] = form_data.negative_prompt if request.app.state.config.AUTOMATIC1111_PARAMS: data = {**data, **request.app.state.config.AUTOMATIC1111_PARAMS} @@ -815,33 +728,33 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img', json=data, - headers={"authorization": get_automatic1111_api_auth(request)}, + headers={'authorization': get_automatic1111_api_auth(request)}, ) res = r.json() - log.debug(f"res: {res}") + log.debug(f'res: {res}') images = [] - for image in res["images"]: + for image in res['images']: image_data, content_type = get_image_data(image) _, url = upload_image( request, image_data, content_type, - {**data, "info": res["info"], **metadata}, + {**data, 'info': res['info'], **metadata}, user, ) - images.append({"url": url}) + images.append({'url': url}) return images except Exception as e: error = e if r != None: data = r.json() - if "error" in data: - error = data["error"]["message"] + if 'error' in data: + error = data['error']['message'] raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) @@ -855,7 +768,7 @@ class EditImageForm(BaseModel): background: Optional[str] = None -@router.post("/edit") +@router.post('/edit') async def image_edits( request: Request, form_data: EditImageForm, @@ -866,42 +779,33 @@ async def image_edits( width, height = None, None metadata = metadata or {} - if ( - request.app.state.config.IMAGE_EDIT_SIZE - and "x" in request.app.state.config.IMAGE_EDIT_SIZE - ) or (form_data.size and "x" in form_data.size): - size = ( - form_data.size - if form_data.size - else request.app.state.config.IMAGE_EDIT_SIZE - ) - width, height = tuple(map(int, size.split("x"))) + if (request.app.state.config.IMAGE_EDIT_SIZE and 'x' in request.app.state.config.IMAGE_EDIT_SIZE) or ( + form_data.size and 'x' in form_data.size + ): + size = form_data.size if form_data.size else request.app.state.config.IMAGE_EDIT_SIZE + width, height = tuple(map(int, size.split('x'))) - model = ( - request.app.state.config.IMAGE_EDIT_MODEL - if form_data.model is None - else form_data.model - ) + model = request.app.state.config.IMAGE_EDIT_MODEL if form_data.model is None else form_data.model try: async def load_url_image(data): - if data.startswith("data:"): + if data.startswith('data:'): return data - if data.startswith("http://") or data.startswith("https://"): + if data.startswith('http://') or data.startswith('https://'): # Validate URL to prevent SSRF attacks against local/private networks validate_url(data) r = await asyncio.to_thread(requests.get, data) r.raise_for_status() - image_data = base64.b64encode(r.content).decode("utf-8") - return f"data:{r.headers['content-type']};base64,{image_data}" + image_data = base64.b64encode(r.content).decode('utf-8') + return f'data:{r.headers["content-type"]};base64,{image_data}' else: file_id = None - if data.startswith("/api/v1/files"): - file_id = data.split("/api/v1/files/")[1].split("/content")[0] + if data.startswith('/api/v1/files'): + file_id = data.split('/api/v1/files/')[1].split('/content')[0] else: file_id = data @@ -909,12 +813,12 @@ async def image_edits( if isinstance(file_response, FileResponse): file_path = file_response.path - with open(file_path, "rb") as f: + with open(file_path, 'rb') as f: file_bytes = f.read() - image_data = base64.b64encode(file_bytes).decode("utf-8") + image_data = base64.b64encode(file_bytes).decode('utf-8') mime_type, _ = mimetypes.guess_type(file_path) - return f"data:{mime_type};base64,{image_data}" + return f'data:{mime_type};base64,{image_data}' return data # Load image(s) from URL(s) if necessary @@ -922,51 +826,47 @@ async def image_edits( form_data.image = await load_url_image(form_data.image) elif isinstance(form_data.image, list): # Load all images in parallel for better performance - form_data.image = list( - await asyncio.gather(*[load_url_image(img) for img in form_data.image]) - ) + form_data.image = list(await asyncio.gather(*[load_url_image(img) for img in form_data.image])) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) - def get_image_file_item(base64_string, param_name="image"): + def get_image_file_item(base64_string, param_name='image'): data = base64_string - header, encoded = data.split(",", 1) - mime_type = header.split(";")[0].lstrip("data:") + header, encoded = data.split(',', 1) + mime_type = header.split(';')[0].lstrip('data:') image_data = base64.b64decode(encoded) return ( param_name, ( - f"{uuid.uuid4()}.png", + f'{uuid.uuid4()}.png', io.BytesIO(image_data), - mime_type if mime_type else "image/png", + mime_type if mime_type else 'image/png', ), ) r = None try: - if request.app.state.config.IMAGE_EDIT_ENGINE == "openai": + if request.app.state.config.IMAGE_EDIT_ENGINE == 'openai': headers = { - "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}", + 'Authorization': f'Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}', } if ENABLE_FORWARD_USER_INFO_HEADERS: headers = include_user_info_headers(headers, user) data = { - "model": model, - "prompt": form_data.prompt, - **({"n": form_data.n} if form_data.n else {}), - **({"size": size} if size else {}), - **( - {"background": form_data.background} if form_data.background else {} - ), + 'model': model, + 'prompt': form_data.prompt, + **({'n': form_data.n} if form_data.n else {}), + **({'size': size} if size else {}), + **({'background': form_data.background} if form_data.background else {}), **( {} if re.match( IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, request.app.state.config.IMAGE_EDIT_MODEL, ) - else {"response_format": "b64_json"} + else {'response_format': 'b64_json'} ), } @@ -975,16 +875,16 @@ async def image_edits( files = [get_image_file_item(form_data.image)] elif isinstance(form_data.image, list): for img in form_data.image: - files.append(get_image_file_item(img, "image[]")) + files.append(get_image_file_item(img, 'image[]')) - url_search_params = "" + url_search_params = '' if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION: - url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}" + url_search_params += f'?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}' # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}", + url=f'{request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}', headers=headers, files=files, data=data, @@ -994,46 +894,44 @@ async def image_edits( res = r.json() images = [] - for image in res["data"]: - if image_url := image.get("url", None): + for image in res['data']: + if image_url := image.get('url', None): image_data, content_type = get_image_data( image_url, - {k: v for k, v in headers.items() if k != "Content-Type"}, + {k: v for k, v in headers.items() if k != 'Content-Type'}, ) else: - image_data, content_type = get_image_data(image["b64_json"]) + image_data, content_type = get_image_data(image['b64_json']) - _, url = upload_image( - request, image_data, content_type, {**data, **metadata}, user - ) - images.append({"url": url}) + _, url = upload_image(request, image_data, content_type, {**data, **metadata}, user) + images.append({'url': url}) return images - elif request.app.state.config.IMAGE_EDIT_ENGINE == "gemini": + elif request.app.state.config.IMAGE_EDIT_ENGINE == 'gemini': headers = { - "Content-Type": "application/json", - "x-goog-api-key": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, + 'Content-Type': 'application/json', + 'x-goog-api-key': request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, } - model = f"{model}:generateContent" - data = {"contents": [{"parts": [{"text": form_data.prompt}]}]} + model = f'{model}:generateContent' + data = {'contents': [{'parts': [{'text': form_data.prompt}]}]} if isinstance(form_data.image, str): - data["contents"][0]["parts"].append( + data['contents'][0]['parts'].append( { - "inline_data": { - "mime_type": "image/png", - "data": form_data.image.split(",", 1)[1], + 'inline_data': { + 'mime_type': 'image/png', + 'data': form_data.image.split(',', 1)[1], } } ) elif isinstance(form_data.image, list): - data["contents"][0]["parts"].extend( + data['contents'][0]['parts'].extend( [ { - "inline_data": { - "mime_type": "image/png", - "data": image.split(",", 1)[1], + 'inline_data': { + 'mime_type': 'image/png', + 'data': image.split(',', 1)[1], } } for image in form_data.image @@ -1043,7 +941,7 @@ async def image_edits( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}", + url=f'{request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}', json=data, headers=headers, ) @@ -1052,12 +950,10 @@ async def image_edits( res = r.json() images = [] - for image in res["candidates"]: - for part in image["content"]["parts"]: - if part.get("inlineData", {}).get("data"): - image_data, content_type = get_image_data( - part["inlineData"]["data"] - ) + for image in res['candidates']: + for part in image['content']['parts']: + if part.get('inlineData', {}).get('data'): + image_data, content_type = get_image_data(part['inlineData']['data']) _, url = upload_image( request, image_data, @@ -1065,11 +961,11 @@ async def image_edits( {**data, **metadata}, user, ) - images.append({"url": url}) + images.append({'url': url}) return images - elif request.app.state.config.IMAGE_EDIT_ENGINE == "comfyui": + elif request.app.state.config.IMAGE_EDIT_ENGINE == 'comfyui': try: files = [] if isinstance(form_data.image, str): @@ -1086,25 +982,25 @@ async def image_edits( request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, ) - comfyui_images.append(res.get("name", file_item[1][0])) + comfyui_images.append(res.get('name', file_item[1][0])) except Exception as e: - log.debug(f"Error uploading images to ComfyUI: {e}") - raise Exception("Failed to upload images to ComfyUI.") + log.debug(f'Error uploading images to ComfyUI: {e}') + raise Exception('Failed to upload images to ComfyUI.') data = { - "image": comfyui_images, - "prompt": form_data.prompt, - **({"width": width} if width is not None else {}), - **({"height": height} if height is not None else {}), - **({"n": form_data.n} if form_data.n else {}), + 'image': comfyui_images, + 'prompt': form_data.prompt, + **({'width': width} if width is not None else {}), + **({'height': height} if height is not None else {}), + **({'n': form_data.n} if form_data.n else {}), } form_data = ComfyUIEditImageForm( **{ - "workflow": ComfyUIWorkflow( + 'workflow': ComfyUIWorkflow( **{ - "workflow": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, - "nodes": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, + 'workflow': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW, + 'nodes': request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, } ), **data, @@ -1117,27 +1013,25 @@ async def image_edits( request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL, request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY, ) - log.debug(f"res: {res}") + log.debug(f'res: {res}') image_urls = set() - for image in res["data"]: - image_urls.add(image["url"]) + for image in res['data']: + image_urls.add(image['url']) image_urls = list(image_urls) # Prioritize output type URLs if available - output_type_urls = [url for url in image_urls if "type=output" in url] + output_type_urls = [url for url in image_urls if 'type=output' in url] if output_type_urls: image_urls = output_type_urls - log.debug(f"Image URLs: {image_urls}") + log.debug(f'Image URLs: {image_urls}') images = [] for image_url in image_urls: headers = None if request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY: - headers = { - "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}" - } + headers = {'Authorization': f'Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}'} image_data, content_type = get_image_data(image_url, headers) _, url = upload_image( @@ -1147,7 +1041,7 @@ async def image_edits( {**form_data.model_dump(exclude_none=True), **metadata}, user, ) - images.append({"url": url}) + images.append({'url': url}) return images except Exception as e: @@ -1156,8 +1050,8 @@ async def image_edits( data = r.text try: data = json.loads(data) - if "error" in data: - error = data["error"]["message"] + if 'error' in data: + error = data['error']['message'] except Exception: error = data diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 8bc58b4371..199ea110e7 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -50,7 +50,7 @@ PAGE_ITEM_COUNT = 30 # Knowledge Base Embedding ############################ -KNOWLEDGE_BASES_COLLECTION = "knowledge-bases" +KNOWLEDGE_BASES_COLLECTION = 'knowledge-bases' async def embed_knowledge_base_metadata( @@ -61,24 +61,24 @@ async def embed_knowledge_base_metadata( ) -> bool: """Generate and store embedding for knowledge base.""" try: - content = f"{name}\n\n{description}" if description else name + content = f'{name}\n\n{description}' if description else name embedding = await request.app.state.EMBEDDING_FUNCTION(content) VECTOR_DB_CLIENT.upsert( collection_name=KNOWLEDGE_BASES_COLLECTION, items=[ { - "id": knowledge_base_id, - "text": content, - "vector": embedding, - "metadata": { - "knowledge_base_id": knowledge_base_id, + 'id': knowledge_base_id, + 'text': content, + 'vector': embedding, + 'metadata': { + 'knowledge_base_id': knowledge_base_id, }, } ], ) return True except Exception as e: - log.error(f"Failed to embed knowledge base {knowledge_base_id}: {e}") + log.error(f'Failed to embed knowledge base {knowledge_base_id}: {e}') return False @@ -91,7 +91,7 @@ def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bool: ) return True except Exception as e: - log.debug(f"Failed to remove embedding for {knowledge_base_id}: {e}") + log.debug(f'Failed to remove embedding for {knowledge_base_id}: {e}') return False @@ -104,7 +104,7 @@ class KnowledgeAccessListResponse(BaseModel): total: int -@router.get("/", response_model=KnowledgeAccessListResponse) +@router.get('/', response_model=KnowledgeAccessListResponse) async def get_knowledge_bases( page: Optional[int] = 1, user=Depends(get_verified_user), @@ -118,23 +118,21 @@ async def get_knowledge_bases( groups = Groups.get_groups_by_member_id(user.id, db=db) user_group_ids = {group.id for group in groups} - if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: + if not user.role == 'admin' or not BYPASS_ADMIN_ACCESS_CONTROL: if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id - result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit, db=db - ) + result = Knowledges.search_knowledge_bases(user.id, filter=filter, skip=skip, limit=limit, db=db) # Batch-fetch writable knowledge IDs in a single query instead of N has_access calls knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items] writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_ids=knowledge_base_ids, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ) @@ -145,7 +143,7 @@ async def get_knowledge_bases( **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or knowledge_base.id in writable_knowledge_base_ids ), ) @@ -155,7 +153,7 @@ async def get_knowledge_bases( ) -@router.get("/search", response_model=KnowledgeAccessListResponse) +@router.get('/search', response_model=KnowledgeAccessListResponse) async def search_knowledge_bases( query: Optional[str] = None, view_option: Optional[str] = None, @@ -169,30 +167,28 @@ async def search_knowledge_bases( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option groups = Groups.get_groups_by_member_id(user.id, db=db) user_group_ids = {group.id for group in groups} - if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: + if not user.role == 'admin' or not BYPASS_ADMIN_ACCESS_CONTROL: if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id - result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit, db=db - ) + result = Knowledges.search_knowledge_bases(user.id, filter=filter, skip=skip, limit=limit, db=db) # Batch-fetch writable knowledge IDs in a single query instead of N has_access calls knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items] writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_ids=knowledge_base_ids, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ) @@ -203,7 +199,7 @@ async def search_knowledge_bases( **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or knowledge_base.id in writable_knowledge_base_ids ), ) @@ -213,7 +209,7 @@ async def search_knowledge_bases( ) -@router.get("/search/files", response_model=KnowledgeFileListResponse) +@router.get('/search/files', response_model=KnowledgeFileListResponse) async def search_knowledge_files( query: Optional[str] = None, page: Optional[int] = 1, @@ -226,17 +222,15 @@ async def search_knowledge_files( filter = {} if query: - filter["query"] = query + filter['query'] = query groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id - return Knowledges.search_knowledge_files( - filter=filter, skip=skip, limit=limit, db=db - ) + return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit, db=db) ############################ @@ -244,7 +238,7 @@ async def search_knowledge_files( ############################ -@router.post("/create", response_model=Optional[KnowledgeResponse]) +@router.post('/create', response_model=Optional[KnowledgeResponse]) async def create_new_knowledge( request: Request, form_data: KnowledgeForm, @@ -254,8 +248,8 @@ async def create_new_knowledge( # Database operations (has_permission, filter_allowed_access_grants, insert_new_knowledge) manage their own sessions. # This prevents holding a connection during embed_knowledge_base_metadata() # which makes external embedding API calls (1-5+ seconds). - if user.role != "admin" and not has_permission( - user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + if user.role != 'admin' and not has_permission( + user.id, 'workspace.knowledge', request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -267,7 +261,7 @@ async def create_new_knowledge( user.id, user.role, form_data.access_grants, - "sharing.public_knowledge", + 'sharing.public_knowledge', ) knowledge = Knowledges.insert_new_knowledge(user.id, form_data) @@ -293,13 +287,13 @@ async def create_new_knowledge( ############################ -@router.post("/reindex", response_model=bool) +@router.post('/reindex', response_model=bool) async def reindex_knowledge_files( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, @@ -307,18 +301,16 @@ async def reindex_knowledge_files( knowledge_bases = Knowledges.get_knowledge_bases(db=db) - log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases") + log.info(f'Starting reindexing for {len(knowledge_bases)} knowledge bases') for knowledge_base in knowledge_bases: try: files = Knowledges.get_files_by_id(knowledge_base.id, db=db) try: if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id): - VECTOR_DB_CLIENT.delete_collection( - collection_name=knowledge_base.id - ) + VECTOR_DB_CLIENT.delete_collection(collection_name=knowledge_base.id) except Exception as e: - log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}") + log.error(f'Error deleting collection {knowledge_base.id}: {str(e)}') continue # Skip, don't raise failed_files = [] @@ -327,32 +319,26 @@ async def reindex_knowledge_files( await run_in_threadpool( process_file, request, - ProcessFileForm( - file_id=file.id, collection_name=knowledge_base.id - ), + ProcessFileForm(file_id=file.id, collection_name=knowledge_base.id), user=user, db=db, ) except Exception as e: - log.error( - f"Error processing file {file.filename} (ID: {file.id}): {str(e)}" - ) - failed_files.append({"file_id": file.id, "error": str(e)}) + log.error(f'Error processing file {file.filename} (ID: {file.id}): {str(e)}') + failed_files.append({'file_id': file.id, 'error': str(e)}) continue except Exception as e: - log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}") + log.error(f'Error processing knowledge base {knowledge_base.id}: {str(e)}') # Don't raise, just continue continue if failed_files: - log.warning( - f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}" - ) + log.warning(f'Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}') for failed in failed_files: - log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}") + log.warning(f'File ID: {failed["file_id"]}, Error: {failed["error"]}') - log.info(f"Reindexing completed.") + log.info(f'Reindexing completed.') return True @@ -361,7 +347,7 @@ async def reindex_knowledge_files( ############################ -@router.post("/metadata/reindex", response_model=dict) +@router.post('/metadata/reindex', response_model=dict) async def reindex_knowledge_base_metadata_embeddings( request: Request, user=Depends(get_admin_user), @@ -374,15 +360,15 @@ async def reindex_knowledge_base_metadata_embeddings( this entire operation would exhaust the connection pool. """ knowledge_bases = Knowledges.get_knowledge_bases() - log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases") + log.info(f'Reindexing embeddings for {len(knowledge_bases)} knowledge bases') success_count = 0 for kb in knowledge_bases: if await embed_knowledge_base_metadata(request, kb.id, kb.name, kb.description): success_count += 1 - log.info(f"Embedding reindex complete: {success_count}/{len(knowledge_bases)}") - return {"total": len(knowledge_bases), "success": success_count} + log.info(f'Embedding reindex complete: {success_count}/{len(knowledge_bases)}') + return {'total': len(knowledge_bases), 'success': success_count} ############################ @@ -395,35 +381,32 @@ class KnowledgeFilesResponse(KnowledgeResponse): write_access: Optional[bool] = False -@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) -async def get_knowledge_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{id}', response_model=Optional[KnowledgeFilesResponse]) +async def get_knowledge_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if knowledge: if ( - user.role == "admin" + user.role == 'admin' or knowledge.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="read", + permission='read', db=db, ) ): - return KnowledgeFilesResponse( **knowledge.model_dump(), write_access=( user.id == knowledge.user_id - or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) ), @@ -445,7 +428,7 @@ async def get_knowledge_by_id( ############################ -@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/update', response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( request: Request, id: str, @@ -467,11 +450,11 @@ async def update_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -483,7 +466,7 @@ async def update_knowledge_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_knowledge", + 'sharing.public_knowledge', ) knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) @@ -515,7 +498,7 @@ class KnowledgeAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/{id}/access/update", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/access/update', response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_access_by_id( request: Request, id: str, @@ -534,12 +517,12 @@ async def update_knowledge_access_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -551,10 +534,10 @@ async def update_knowledge_access_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_knowledge", + 'sharing.public_knowledge', ) - AccessGrants.set_access_grants("knowledge", id, form_data.access_grants, db=db) + AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db) return KnowledgeFilesResponse( **Knowledges.get_knowledge_by_id(id=id, db=db).model_dump(), @@ -567,7 +550,7 @@ async def update_knowledge_access_by_id( ############################ -@router.get("/{id}/files", response_model=KnowledgeFileListResponse) +@router.get('/{id}/files', response_model=KnowledgeFileListResponse) async def get_knowledge_files_by_id( id: str, query: Optional[str] = None, @@ -578,7 +561,6 @@ async def get_knowledge_files_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( @@ -587,13 +569,13 @@ async def get_knowledge_files_by_id( ) if not ( - user.role == "admin" + user.role == 'admin' or knowledge.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="read", + permission='read', db=db, ) ): @@ -609,17 +591,15 @@ async def get_knowledge_files_by_id( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction - return Knowledges.search_files_by_id( - id, user.id, filter=filter, skip=skip, limit=limit, db=db - ) + return Knowledges.search_files_by_id(id, user.id, filter=filter, skip=skip, limit=limit, db=db) ############################ @@ -631,7 +611,7 @@ class KnowledgeFileIdForm(BaseModel): file_id: str -@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/add', response_model=Optional[KnowledgeFilesResponse]) def add_file_to_knowledge_by_id( request: Request, id: str, @@ -650,12 +630,12 @@ def add_file_to_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -684,9 +664,7 @@ def add_file_to_knowledge_by_id( ) # Add file to knowledge base - Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db - ) + Knowledges.add_file_to_knowledge_by_id(knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db) except Exception as e: log.debug(e) raise HTTPException( @@ -706,7 +684,7 @@ def add_file_to_knowledge_by_id( ) -@router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/update', response_model=Optional[KnowledgeFilesResponse]) def update_file_from_knowledge_by_id( request: Request, id: str, @@ -725,14 +703,13 @@ def update_file_from_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): - raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -753,9 +730,7 @@ def update_file_from_knowledge_by_id( ) # Remove content from the vector database - VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"file_id": form_data.file_id} - ) + VECTOR_DB_CLIENT.delete(collection_name=knowledge.id, filter={'file_id': form_data.file_id}) # Add content to the vector database try: @@ -788,7 +763,7 @@ def update_file_from_knowledge_by_id( ############################ -@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/remove', response_model=Optional[KnowledgeFilesResponse]) def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, @@ -807,12 +782,12 @@ def remove_file_from_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -833,32 +808,30 @@ def remove_file_from_knowledge_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - Knowledges.remove_file_from_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id, db=db - ) + Knowledges.remove_file_from_knowledge_by_id(knowledge_id=id, file_id=form_data.file_id, db=db) # Remove content from the vector database try: VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"file_id": form_data.file_id} + collection_name=knowledge.id, filter={'file_id': form_data.file_id} ) # Remove by file_id first VECTOR_DB_CLIENT.delete( - collection_name=knowledge.id, filter={"hash": file.hash} + collection_name=knowledge.id, filter={'hash': file.hash} ) # Remove by hash as well in case of duplicates except Exception as e: - log.debug("This was most likely caused by bypassing embedding processing") + log.debug('This was most likely caused by bypassing embedding processing') log.debug(e) pass if delete_file: try: # Remove the file's collection from vector database - file_collection = f"file-{form_data.file_id}" + file_collection = f'file-{form_data.file_id}' if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection): VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection) except Exception as e: - log.debug("This was most likely caused by bypassing embedding processing") + log.debug('This was most likely caused by bypassing embedding processing') log.debug(e) pass @@ -882,10 +855,8 @@ def remove_file_from_knowledge_by_id( ############################ -@router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/{id}/delete', response_model=bool) +async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( @@ -897,34 +868,34 @@ async def delete_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") + log.info(f'Deleting knowledge base: {id} (name: {knowledge.name})') # Get all models models = Models.get_all_models(db=db) - log.info(f"Found {len(models)} models to check for knowledge base {id}") + log.info(f'Found {len(models)} models to check for knowledge base {id}') # Update models that reference this knowledge base for model in models: - if model.meta and hasattr(model.meta, "knowledge"): + if model.meta and hasattr(model.meta, 'knowledge'): knowledge_list = model.meta.knowledge or [] # Filter out the deleted knowledge base - updated_knowledge = [k for k in knowledge_list if k.get("id") != id] + updated_knowledge = [k for k in knowledge_list if k.get('id') != id] # If the knowledge list changed, update the model if len(updated_knowledge) != len(knowledge_list): - log.info(f"Updating model {model.id} to remove knowledge base {id}") + log.info(f'Updating model {model.id} to remove knowledge base {id}') model.meta.knowledge = updated_knowledge # Create a ModelForm for the update model_form = ModelForm( @@ -957,10 +928,8 @@ async def delete_knowledge_by_id( ############################ -@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/{id}/reset', response_model=Optional[KnowledgeResponse]) +async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( @@ -972,12 +941,12 @@ async def reset_knowledge_by_id( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -999,7 +968,7 @@ async def reset_knowledge_by_id( ############################ -@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/files/batch/add', response_model=Optional[KnowledgeFilesResponse]) async def add_files_to_knowledge_batch( request: Request, id: str, @@ -1021,12 +990,12 @@ async def add_files_to_knowledge_batch( knowledge.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -1034,7 +1003,7 @@ async def add_files_to_knowledge_batch( ) # Batch-fetch all files to avoid N+1 queries - log.info(f"files/batch/add - {len(form_data)} files") + log.info(f'files/batch/add - {len(form_data)} files') file_ids = [form.file_id for form in form_data] files = Files.get_files_by_ids(file_ids, db=db) @@ -1044,7 +1013,7 @@ async def add_files_to_knowledge_batch( if missing_ids: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"File {missing_ids[0]} not found", + detail=f'File {missing_ids[0]} not found', ) # Process files @@ -1056,27 +1025,23 @@ async def add_files_to_knowledge_batch( db=db, ) except Exception as e: - log.error( - f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True - ) + log.error(f'add_files_to_knowledge_batch: Exception occurred: {e}', exc_info=True) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) # Only add files that were successfully processed - successful_file_ids = [r.file_id for r in result.results if r.status == "completed"] + successful_file_ids = [r.file_id for r in result.results if r.status == 'completed'] for file_id in successful_file_ids: - Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=file_id, user_id=user.id, db=db - ) + Knowledges.add_file_to_knowledge_by_id(knowledge_id=id, file_id=file_id, user_id=user.id, db=db) # If there were any errors, include them in the response if result.errors: - error_details = [f"{err.file_id}: {err.error}" for err in result.errors] + error_details = [f'{err.file_id}: {err.error}' for err in result.errors] return KnowledgeFilesResponse( **knowledge.model_dump(), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), warnings={ - "message": "Some files failed to process", - "errors": error_details, + 'message': 'Some files failed to process', + 'errors': error_details, }, ) @@ -1091,10 +1056,8 @@ async def add_files_to_knowledge_batch( ############################ -@router.get("/{id}/export") -async def export_knowledge_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/{id}/export') +async def export_knowledge_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): """ Export a knowledge base as a zip file containing .txt files. Admin only. @@ -1111,24 +1074,24 @@ async def export_knowledge_by_id( # Create zip file in memory zip_buffer = io.BytesIO() - with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: for file in files: - content = file.data.get("content", "") if file.data else "" + content = file.data.get('content', '') if file.data else '' if content: # Use original filename with .txt extension filename = file.filename - if not filename.endswith(".txt"): - filename = f"{filename}.txt" + if not filename.endswith('.txt'): + filename = f'{filename}.txt' zf.writestr(filename, content) zip_buffer.seek(0) # Sanitize knowledge name for filename - safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in knowledge.name) - zip_filename = f"{safe_name}.zip" + safe_name = ''.join(c if c.isalnum() or c in ' -_' else '_' for c in knowledge.name) + zip_filename = f'{safe_name}.zip' return StreamingResponse( zip_buffer, - media_type="application/zip", - headers={"Content-Disposition": f"attachment; filename={zip_filename}"}, + media_type='application/zip', + headers={'Content-Disposition': f'attachment; filename={zip_filename}'}, ) diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index db2bf5e238..82af3a580c 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -23,7 +23,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[MemoryModel]) +@router.get('/', response_model=list[MemoryModel]) async def get_memories( request: Request, user=Depends(get_verified_user), @@ -35,9 +35,7 @@ async def get_memories( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -59,7 +57,7 @@ class MemoryUpdateModel(BaseModel): content: Optional[str] = None -@router.post("/add", response_model=Optional[MemoryModel]) +@router.post('/add', response_model=Optional[MemoryModel]) async def add_memory( request: Request, form_data: AddMemoryForm, @@ -75,9 +73,7 @@ async def add_memory( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -88,13 +84,13 @@ async def add_memory( vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) VECTOR_DB_CLIENT.upsert( - collection_name=f"user-memory-{user.id}", + collection_name=f'user-memory-{user.id}', items=[ { - "id": memory.id, - "text": memory.content, - "vector": vector, - "metadata": {"created_at": memory.created_at}, + 'id': memory.id, + 'text': memory.content, + 'vector': vector, + 'metadata': {'created_at': memory.created_at}, } ], ) @@ -112,7 +108,7 @@ class QueryMemoryForm(BaseModel): k: Optional[int] = 1 -@router.post("/query") +@router.post('/query') async def query_memory( request: Request, form_data: QueryMemoryForm, @@ -128,9 +124,7 @@ async def query_memory( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -138,12 +132,12 @@ async def query_memory( memories = Memories.get_memories_by_user_id(user.id) if not memories: - raise HTTPException(status_code=404, detail="No memories found for user") + raise HTTPException(status_code=404, detail='No memories found for user') vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user) results = VECTOR_DB_CLIENT.search( - collection_name=f"user-memory-{user.id}", + collection_name=f'user-memory-{user.id}', vectors=[vector], limit=form_data.k, ) @@ -154,7 +148,7 @@ async def query_memory( ############################ # ResetMemoryFromVectorDB ############################ -@router.post("/reset", response_model=bool) +@router.post('/reset', response_model=bool) async def reset_memory_from_vector_db( request: Request, user=Depends(get_verified_user), @@ -173,36 +167,31 @@ async def reset_memory_from_vector_db( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f'user-memory-{user.id}') memories = Memories.get_memories_by_user_id(user.id) # Generate vectors in parallel vectors = await asyncio.gather( - *[ - request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) - for memory in memories - ] + *[request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) for memory in memories] ) VECTOR_DB_CLIENT.upsert( - collection_name=f"user-memory-{user.id}", + collection_name=f'user-memory-{user.id}', items=[ { - "id": memory.id, - "text": memory.content, - "vector": vectors[idx], - "metadata": { - "created_at": memory.created_at, - "updated_at": memory.updated_at, + 'id': memory.id, + 'text': memory.content, + 'vector': vectors[idx], + 'metadata': { + 'created_at': memory.created_at, + 'updated_at': memory.updated_at, }, } for idx, memory in enumerate(memories) @@ -217,7 +206,7 @@ async def reset_memory_from_vector_db( ############################ -@router.delete("/delete/user", response_model=bool) +@router.delete('/delete/user', response_model=bool) async def delete_memory_by_user_id( request: Request, user=Depends(get_verified_user), @@ -229,9 +218,7 @@ async def delete_memory_by_user_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -241,7 +228,7 @@ async def delete_memory_by_user_id( if result: try: - VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f'user-memory-{user.id}') except Exception as e: log.error(e) return True @@ -254,7 +241,7 @@ async def delete_memory_by_user_id( ############################ -@router.post("/{memory_id}/update", response_model=Optional[MemoryModel]) +@router.post('/{memory_id}/update', response_model=Optional[MemoryModel]) async def update_memory_by_id( memory_id: str, request: Request, @@ -271,33 +258,29 @@ async def update_memory_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - memory = Memories.update_memory_by_id_and_user_id( - memory_id, user.id, form_data.content - ) + memory = Memories.update_memory_by_id_and_user_id(memory_id, user.id, form_data.content) if memory is None: - raise HTTPException(status_code=404, detail="Memory not found") + raise HTTPException(status_code=404, detail='Memory not found') if form_data.content is not None: vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) VECTOR_DB_CLIENT.upsert( - collection_name=f"user-memory-{user.id}", + collection_name=f'user-memory-{user.id}', items=[ { - "id": memory.id, - "text": memory.content, - "vector": vector, - "metadata": { - "created_at": memory.created_at, - "updated_at": memory.updated_at, + 'id': memory.id, + 'text': memory.content, + 'vector': vector, + 'metadata': { + 'created_at': memory.created_at, + 'updated_at': memory.updated_at, }, } ], @@ -311,7 +294,7 @@ async def update_memory_by_id( ############################ -@router.delete("/{memory_id}", response_model=bool) +@router.delete('/{memory_id}', response_model=bool) async def delete_memory_by_id( memory_id: str, request: Request, @@ -324,9 +307,7 @@ async def delete_memory_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if not has_permission( - user.id, "features.memories", request.app.state.config.USER_PERMISSIONS - ): + if not has_permission(user.id, 'features.memories', request.app.state.config.USER_PERMISSIONS): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -335,9 +316,7 @@ async def delete_memory_by_id( result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db) if result: - VECTOR_DB_CLIENT.delete( - collection_name=f"user-memory-{user.id}", ids=[memory_id] - ) + VECTOR_DB_CLIENT.delete(collection_name=f'user-memory-{user.id}', ids=[memory_id]) return True return False diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 95279bc378..9dc602dc0e 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -55,9 +55,7 @@ def is_valid_model_id(model_id: str) -> bool: PAGE_ITEM_COUNT = 30 -@router.get( - "/list", response_model=ModelAccessListResponse -) # do NOT use "/" as path, conflicts with main.py +@router.get('/list', response_model=ModelAccessListResponse) # do NOT use "/" as path, conflicts with main.py async def get_models( query: Optional[str] = None, view_option: Optional[str] = None, @@ -68,7 +66,6 @@ async def get_models( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - limit = PAGE_ITEM_COUNT page = max(1, page) @@ -76,25 +73,25 @@ async def get_models( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option if tag: - filter["tag"] = tag + filter['tag'] = tag if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction # Pre-fetch user group IDs once - used for both filter and write_access check groups = Groups.get_groups_by_member_id(user.id, db=db) user_group_ids = {group.id for group in groups} - if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: + if not user.role == 'admin' or not BYPASS_ADMIN_ACCESS_CONTROL: if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id result = Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db) @@ -102,9 +99,9 @@ async def get_models( model_ids = [model.id for model in result.items] writable_model_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="model", + resource_type='model', resource_ids=model_ids, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ) @@ -114,7 +111,7 @@ async def get_models( ModelAccessResponse( **model.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id or model.id in writable_model_ids ), @@ -130,10 +127,8 @@ async def get_models( ########################### -@router.get("/base", response_model=list[ModelResponse]) -async def get_base_models( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/base', response_model=list[ModelResponse]) +async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): return Models.get_base_models(db=db) @@ -142,11 +137,9 @@ async def get_base_models( ########################### -@router.get("/tags", response_model=list[str]) -async def get_model_tags( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: +@router.get('/tags', response_model=list[str]) +async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: models = Models.get_models(db=db) else: models = Models.get_models_by_user_id(user.id, db=db) @@ -155,8 +148,8 @@ async def get_model_tags( for model in models: if model.meta: meta = model.meta.model_dump() - for tag in meta.get("tags", []): - tags_set.add((tag.get("name"))) + for tag in meta.get('tags', []): + tags_set.add((tag.get('name'))) tags = [tag for tag in tags_set] tags.sort() @@ -168,15 +161,15 @@ async def get_model_tags( ############################ -@router.post("/create", response_model=Optional[ModelModel]) +@router.post('/create', response_model=Optional[ModelModel]) async def create_new_model( request: Request, form_data: ModelForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -212,15 +205,15 @@ async def create_new_model( ############################ -@router.get("/export", response_model=list[ModelModel]) +@router.get('/export', response_model=list[ModelModel]) async def export_models( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( + if user.role != 'admin' and not has_permission( user.id, - "workspace.models_export", + 'workspace.models_export', request.app.state.config.USER_PERMISSIONS, db=db, ): @@ -229,7 +222,7 @@ async def export_models( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: return Models.get_models(db=db) else: return Models.get_models_by_user_id(user.id, db=db) @@ -244,16 +237,16 @@ class ModelsImportForm(BaseModel): models: list[dict] -@router.post("/import", response_model=bool) +@router.post('/import', response_model=bool) async def import_models( request: Request, user=Depends(get_verified_user), form_data: ModelsImportForm = (...), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( + if user.role != 'admin' and not has_permission( user.id, - "workspace.models_import", + 'workspace.models_import', request.app.state.config.USER_PERMISSIONS, db=db, ): @@ -266,43 +259,36 @@ async def import_models( if isinstance(data, list): # Batch-fetch all existing models in one query to avoid N+1 model_ids = [ - model_data.get("id") + model_data.get('id') for model_data in data - if model_data.get("id") and is_valid_model_id(model_data.get("id")) + if model_data.get('id') and is_valid_model_id(model_data.get('id')) ] existing_models = { - model.id: model - for model in ( - Models.get_models_by_ids(model_ids, db=db) if model_ids else [] - ) + model.id: model for model in (Models.get_models_by_ids(model_ids, db=db) if model_ids else []) } for model_data in data: # Here, you can add logic to validate model_data if needed - model_id = model_data.get("id") + model_id = model_data.get('id') if model_id and is_valid_model_id(model_id): existing_model = existing_models.get(model_id) if existing_model: # Update existing model - model_data["meta"] = model_data.get("meta", {}) - model_data["params"] = model_data.get("params", {}) + model_data['meta'] = model_data.get('meta', {}) + model_data['params'] = model_data.get('params', {}) - updated_model = ModelForm( - **{**existing_model.model_dump(), **model_data} - ) + updated_model = ModelForm(**{**existing_model.model_dump(), **model_data}) Models.update_model_by_id(model_id, updated_model, db=db) else: # Insert new model - model_data["meta"] = model_data.get("meta", {}) - model_data["params"] = model_data.get("params", {}) + model_data['meta'] = model_data.get('meta', {}) + model_data['params'] = model_data.get('params', {}) new_model = ModelForm(**model_data) - Models.insert_new_model( - user_id=user.id, form_data=new_model, db=db - ) + Models.insert_new_model(user_id=user.id, form_data=new_model, db=db) return True else: - raise HTTPException(status_code=400, detail="Invalid JSON format") + raise HTTPException(status_code=400, detail='Invalid JSON format') except Exception as e: log.exception(e) raise HTTPException(status_code=500, detail=str(e)) @@ -317,7 +303,7 @@ class SyncModelsForm(BaseModel): models: list[ModelModel] = [] -@router.post("/sync", response_model=list[ModelModel]) +@router.post('/sync', response_model=list[ModelModel]) async def sync_models( request: Request, form_data: SyncModelsForm, @@ -337,33 +323,31 @@ class ModelIdForm(BaseModel): # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id -@router.get("/model", response_model=Optional[ModelAccessResponse]) -async def get_model_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/model', response_model=Optional[ModelAccessResponse]) +async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): model = Models.get_model_by_id(id, db=db) if model: if ( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="read", + permission='read', db=db, ) ): return ModelAccessResponse( **model.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="write", + permission='write', db=db, ) ), @@ -385,7 +369,7 @@ async def get_model_by_id( ########################### -@router.get("/model/profile/image") +@router.get('/model/profile/image') def get_model_profile_image(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) @@ -393,21 +377,21 @@ def get_model_profile_image(id: str, user=Depends(get_verified_user)): etag = f'"{model.updated_at}"' if model.updated_at else None if model.meta.profile_image_url: - if model.meta.profile_image_url.startswith("http"): + if model.meta.profile_image_url.startswith('http'): return Response( status_code=status.HTTP_302_FOUND, - headers={"Location": model.meta.profile_image_url}, + headers={'Location': model.meta.profile_image_url}, ) - elif model.meta.profile_image_url.startswith("data:image"): + elif model.meta.profile_image_url.startswith('data:image'): try: - header, base64_data = model.meta.profile_image_url.split(",", 1) + header, base64_data = model.meta.profile_image_url.split(',', 1) image_data = base64.b64decode(base64_data) image_buffer = io.BytesIO(image_data) - media_type = header.split(";")[0].lstrip("data:") + media_type = header.split(';')[0].lstrip('data:') - headers = {"Content-Disposition": "inline"} + headers = {'Content-Disposition': 'inline'} if etag: - headers["ETag"] = etag + headers['ETag'] = etag return StreamingResponse( image_buffer, @@ -417,9 +401,9 @@ def get_model_profile_image(id: str, user=Depends(get_verified_user)): except Exception as e: pass - return FileResponse(f"{STATIC_DIR}/favicon.png") + return FileResponse(f'{STATIC_DIR}/favicon.png') else: - return FileResponse(f"{STATIC_DIR}/favicon.png") + return FileResponse(f'{STATIC_DIR}/favicon.png') ############################ @@ -427,20 +411,18 @@ def get_model_profile_image(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/model/toggle", response_model=Optional[ModelResponse]) -async def toggle_model_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/model/toggle', response_model=Optional[ModelResponse]) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): model = Models.get_model_by_id(id, db=db) if model: if ( - user.role == "admin" + user.role == 'admin' or model.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="write", + permission='write', db=db, ) ): @@ -451,7 +433,7 @@ async def toggle_model_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + detail=ERROR_MESSAGES.DEFAULT('Error updating function'), ) else: raise HTTPException( @@ -470,7 +452,7 @@ async def toggle_model_by_id( ############################ -@router.post("/model/update", response_model=Optional[ModelModel]) +@router.post('/model/update', response_model=Optional[ModelModel]) async def update_model_by_id( form_data: ModelForm, user=Depends(get_verified_user), @@ -487,21 +469,19 @@ async def update_model_by_id( model.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id( - form_data.id, ModelForm(**form_data.model_dump()), db=db - ) + model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db) return model @@ -516,7 +496,7 @@ class ModelAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/model/access/update", response_model=Optional[ModelModel]) +@router.post('/model/access/update', response_model=Optional[ModelModel]) async def update_model_access_by_id( request: Request, form_data: ModelAccessGrantsForm, @@ -528,7 +508,7 @@ async def update_model_access_by_id( # Non-preset models (e.g. direct Ollama/OpenAI models) may not have a DB # entry yet. Create a minimal one so access grants can be stored. if not model: - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -546,19 +526,19 @@ async def update_model_access_by_id( if not model: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT("Error creating model entry"), + detail=ERROR_MESSAGES.DEFAULT('Error creating model entry'), ) if ( model.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -570,12 +550,10 @@ async def update_model_access_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_models", + 'sharing.public_models', ) - AccessGrants.set_access_grants( - "model", form_data.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('model', form_data.id, form_data.access_grants, db=db) return Models.get_model_by_id(form_data.id, db=db) @@ -585,7 +563,7 @@ async def update_model_access_by_id( ############################ -@router.post("/model/delete", response_model=bool) +@router.post('/model/delete', response_model=bool) async def delete_model_by_id( form_data: ModelIdForm, user=Depends(get_verified_user), @@ -599,13 +577,13 @@ async def delete_model_by_id( ) if ( - user.role != "admin" + user.role != 'admin' and model.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model.id, - permission="write", + permission='write', db=db, ) ): @@ -618,9 +596,7 @@ async def delete_model_by_id( return result -@router.delete("/delete/all", response_model=bool) -async def delete_all_models( - user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.delete('/delete/all', response_model=bool) +async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): result = Models.delete_all_models(db=db) return result diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 8a93b54d77..dd826053c8 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -44,8 +44,8 @@ router = APIRouter() def _truncate_note_data(data: Optional[dict], max_length: int = 1000) -> Optional[dict]: if not data: return data - md = (data.get("content") or {}).get("md") or "" - return {"content": {"md": md[:max_length]}} + md = (data.get('content') or {}).get('md') or '' + return {'content': {'md': md[:max_length]}} ############################ @@ -62,15 +62,15 @@ class NoteItemResponse(BaseModel): user: Optional[UserResponse] = None -@router.get("/", response_model=list[NoteItemResponse]) +@router.get('/', response_model=list[NoteItemResponse]) async def get_notes( request: Request, page: Optional[int] = None, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -83,7 +83,7 @@ async def get_notes( limit = 60 skip = (page - 1) * limit - notes = Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit, db=db) + notes = Notes.get_notes_by_user_id(user.id, 'read', skip=skip, limit=limit, db=db) if not notes: return [] @@ -94,8 +94,8 @@ async def get_notes( NoteUserResponse( **{ **note.model_dump(), - "data": _truncate_note_data(note.data), - "user": UserResponse(**users[note.user_id].model_dump()), + 'data': _truncate_note_data(note.data), + 'user': UserResponse(**users[note.user_id].model_dump()), } ) for note in notes @@ -103,7 +103,7 @@ async def get_notes( ] -@router.get("/search", response_model=NoteListResponse) +@router.get('/search', response_model=NoteListResponse) async def search_notes( request: Request, query: Optional[str] = None, @@ -115,8 +115,8 @@ async def search_notes( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -131,22 +131,22 @@ async def search_notes( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option if permission: - filter["permission"] = permission + filter['permission'] = permission if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction - if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: + if not user.role == 'admin' or not BYPASS_ADMIN_ACCESS_CONTROL: groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id result = Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db) for note in result.items: @@ -159,15 +159,15 @@ async def search_notes( ############################ -@router.post("/create", response_model=Optional[NoteModel]) +@router.post('/create', response_model=Optional[NoteModel]) async def create_new_note( request: Request, form_data: NoteForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -179,9 +179,7 @@ async def create_new_note( return note except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -193,15 +191,15 @@ class NoteResponse(NoteModel): write_access: bool = False -@router.get("/{id}", response_model=Optional[NoteResponse]) +@router.get('/{id}', response_model=Optional[NoteResponse]) async def get_note_by_id( request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -210,34 +208,30 @@ async def get_note_by_id( note = Notes.get_note_by_id(id, db=db) if not note: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if user.role != "admin" and ( + if user.role != 'admin' and ( user.id != note.user_id and ( not AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="read", + permission='read', db=db, ) ) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) write_access = ( - user.role == "admin" + user.role == 'admin' or (user.id == note.user_id) or AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="write", + permission='write', db=db, ) or has_public_read_access_grant(note.access_grants) @@ -251,7 +245,7 @@ async def get_note_by_id( ############################ -@router.post("/{id}/update", response_model=Optional[NoteModel]) +@router.post('/{id}/update', response_model=Optional[NoteModel]) async def update_note_by_id( request: Request, id: str, @@ -259,8 +253,8 @@ async def update_note_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -269,47 +263,41 @@ async def update_note_by_id( note = Notes.get_note_by_id(id, db=db) if not note: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if user.role != "admin" and ( + if user.role != 'admin' and ( user.id != note.user_id and not AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="write", + permission='write', db=db, ) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) form_data.access_grants = filter_allowed_access_grants( request.app.state.config.USER_PERMISSIONS, user.id, user.role, form_data.access_grants, - "sharing.public_notes", + 'sharing.public_notes', db=db, ) try: note = Notes.update_note_by_id(id, form_data, db=db) await sio.emit( - "note-events", + 'note-events', note.model_dump(), - to=f"note:{note.id}", + to=f'note:{note.id}', ) return note except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) ############################ @@ -321,7 +309,7 @@ class NoteAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/{id}/access/update", response_model=Optional[NoteModel]) +@router.post('/{id}/access/update', response_model=Optional[NoteModel]) async def update_note_access_by_id( request: Request, id: str, @@ -329,8 +317,8 @@ async def update_note_access_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -339,33 +327,29 @@ async def update_note_access_by_id( note = Notes.get_note_by_id(id, db=db) if not note: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if user.role != "admin" and ( + if user.role != 'admin' and ( user.id != note.user_id and not AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="write", + permission='write', db=db, ) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) form_data.access_grants = filter_allowed_access_grants( request.app.state.config.USER_PERMISSIONS, user.id, user.role, form_data.access_grants, - "sharing.public_notes", + 'sharing.public_notes', ) - AccessGrants.set_access_grants("note", id, form_data.access_grants, db=db) + AccessGrants.set_access_grants('note', id, form_data.access_grants, db=db) return Notes.get_note_by_id(id, db=db) @@ -375,15 +359,15 @@ async def update_note_access_by_id( ############################ -@router.delete("/{id}/delete", response_model=bool) +@router.delete('/{id}/delete', response_model=bool) async def delete_note_by_id( request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'features.notes', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -392,29 +376,23 @@ async def delete_note_by_id( note = Notes.get_note_by_id(id, db=db) if not note: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) - if user.role != "admin" and ( + if user.role != 'admin' and ( user.id != note.user_id and not AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="write", + permission='write', db=db, ) ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()) try: note = Notes.delete_note_by_id(id, db=db) return True except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index f72cbff5aa..25a18e3e17 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -86,8 +86,8 @@ async def send_get_request(url, key=None, user: UserModel = None): try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: @@ -101,7 +101,7 @@ async def send_get_request(url, key=None, user: UserModel = None): return await response.json() except Exception as e: # Handle connection error here - log.error(f"Connection error: {e}") + log.error(f'Connection error: {e}') return None @@ -114,23 +114,20 @@ async def send_post_request( user: UserModel = None, metadata: Optional[dict] = None, ): - r = None streaming = False try: - session = aiohttp.ClientSession( - trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - ) + session = aiohttp.ClientSession(trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)) headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - if metadata and metadata.get("chat_id"): - headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get("chat_id") + if metadata and metadata.get('chat_id'): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id') r = await session.post( url, @@ -143,15 +140,15 @@ async def send_post_request( try: res = await r.json() await cleanup_response(r, session) - if "error" in res: - raise HTTPException(status_code=r.status, detail=res["error"]) + if 'error' in res: + raise HTTPException(status_code=r.status, detail=res['error']) except HTTPException as e: raise e # Re-raise HTTPException to be handled by FastAPI except Exception as e: - log.error(f"Failed to parse error response: {e}") + log.error(f'Failed to parse error response: {e}') raise HTTPException( status_code=r.status, - detail=f"Open WebUI: Server Connection Error", + detail=f'Open WebUI: Server Connection Error', ) r.raise_for_status() # Raises an error for bad responses (4xx, 5xx) @@ -159,7 +156,7 @@ async def send_post_request( response_headers = dict(r.headers) if content_type: - response_headers["Content-Type"] = content_type + response_headers['Content-Type'] = content_type streaming = True return StreamingResponse( @@ -174,11 +171,11 @@ async def send_post_request( except HTTPException as e: raise e # Re-raise HTTPException to be handled by FastAPI except Exception as e: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status if r else 500, - detail=detail if e else "Open WebUI: Server Connection Error", + detail=detail if e else 'Open WebUI: Server Connection Error', ) finally: if not streaming: @@ -187,10 +184,8 @@ async def send_post_request( def get_api_key(idx, url, configs): parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - return configs.get(str(idx), configs.get(base_url, {})).get( - "key", None - ) # Legacy support + base_url = f'{parsed_url.scheme}://{parsed_url.netloc}' + return configs.get(str(idx), configs.get(base_url, {})).get('key', None) # Legacy support ########################################## @@ -202,10 +197,10 @@ def get_api_key(idx, url, configs): router = APIRouter() -@router.head("/") -@router.get("/") +@router.head('/') +@router.get('/') async def get_status(): - return {"status": True} + return {'status': True} class ConnectionVerificationForm(BaseModel): @@ -213,10 +208,8 @@ class ConnectionVerificationForm(BaseModel): key: Optional[str] = None -@router.post("/verify") -async def verify_connection( - form_data: ConnectionVerificationForm, user=Depends(get_admin_user) -): +@router.post('/verify') +async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)): url = form_data.url key = form_data.key @@ -226,44 +219,42 @@ async def verify_connection( ) as session: try: headers = { - **({"Authorization": f"Bearer {key}"} if key else {}), + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) async with session.get( - f"{url}/api/version", + f'{url}/api/version', headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: - detail = f"HTTP Error: {r.status}" + detail = f'HTTP Error: {r.status}' res = await r.json() - if "error" in res: - detail = f"External Error: {res['error']}" + if 'error' in res: + detail = f'External Error: {res["error"]}' raise Exception(detail) data = await r.json() return data except aiohttp.ClientError as e: - log.exception(f"Client error: {str(e)}") - raise HTTPException( - status_code=500, detail="Open WebUI: Server Connection Error" - ) + log.exception(f'Client error: {str(e)}') + raise HTTPException(status_code=500, detail='Open WebUI: Server Connection Error') except Exception as e: - log.exception(f"Unexpected error: {e}") - error_detail = f"Unexpected error: {str(e)}" + log.exception(f'Unexpected error: {e}') + error_detail = f'Unexpected error: {str(e)}' raise HTTPException(status_code=500, detail=error_detail) -@router.get("/config") +@router.get('/config') async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + 'ENABLE_OLLAMA_API': request.app.state.config.ENABLE_OLLAMA_API, + 'OLLAMA_BASE_URLS': request.app.state.config.OLLAMA_BASE_URLS, + 'OLLAMA_API_CONFIGS': request.app.state.config.OLLAMA_API_CONFIGS, } @@ -273,10 +264,8 @@ class OllamaConfigForm(BaseModel): OLLAMA_API_CONFIGS: dict -@router.post("/config/update") -async def update_config( - request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_config(request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)): request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS @@ -285,15 +274,13 @@ async def update_config( # Remove the API configs that are not in the API URLS keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS)))) request.app.state.config.OLLAMA_API_CONFIGS = { - key: value - for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items() - if key in keys + key: value for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items() if key in keys } return { - "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + 'ENABLE_OLLAMA_API': request.app.state.config.ENABLE_OLLAMA_API, + 'OLLAMA_BASE_URLS': request.app.state.config.OLLAMA_BASE_URLS, + 'OLLAMA_API_CONFIGS': request.app.state.config.OLLAMA_API_CONFIGS, } @@ -303,45 +290,41 @@ def merge_ollama_models_lists(model_lists): for idx, model_list in enumerate(model_lists): if model_list is not None: for model in model_list: - id = model.get("model") + id = model.get('model') if id is not None: if id not in merged_models: - model["urls"] = [idx] + model['urls'] = [idx] merged_models[id] = model else: - merged_models[id]["urls"].append(idx) + merged_models[id]['urls'].append(idx) return list(merged_models.values()) @cached( ttl=MODELS_CACHE_TTL, - key=lambda _, user: f"ollama_all_models_{user.id}" if user else "ollama_all_models", + key=lambda _, user: f'ollama_all_models_{user.id}' if user else 'ollama_all_models', ) async def get_all_models(request: Request, user: UserModel = None): - log.info("get_all_models()") + log.info('get_all_models()') if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/tags", user=user)) + request_tasks.append(send_get_request(f'{url}/api/tags', user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - enable = api_config.get("enable", True) - key = api_config.get("key", None) + enable = api_config.get('enable', True) + key = api_config.get('key', None) if enable: - request_tasks.append( - send_get_request(f"{url}/api/tags", key, user=user) - ) + request_tasks.append(send_get_request(f'{url}/api/tags', key, user=user)) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -352,39 +335,37 @@ async def get_all_models(request: Request, user: UserModel = None): url = request.app.state.config.OLLAMA_BASE_URLS[idx] api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - connection_type = api_config.get("connection_type", "local") + connection_type = api_config.get('connection_type', 'local') - prefix_id = api_config.get("prefix_id", None) - tags = api_config.get("tags", []) - model_ids = api_config.get("model_ids", []) + prefix_id = api_config.get('prefix_id', None) + tags = api_config.get('tags', []) + model_ids = api_config.get('model_ids', []) - if len(model_ids) != 0 and "models" in response: - response["models"] = list( + if len(model_ids) != 0 and 'models' in response: + response['models'] = list( filter( - lambda model: model["model"] in model_ids, - response["models"], + lambda model: model['model'] in model_ids, + response['models'], ) ) - for model in response.get("models", []): + for model in response.get('models', []): if prefix_id: - model["model"] = f"{prefix_id}.{model['model']}" + model['model'] = f'{prefix_id}.{model["model"]}' if tags: - model["tags"] = tags + model['tags'] = tags if connection_type: - model["connection_type"] = connection_type + model['connection_type'] = connection_type models = { - "models": merge_ollama_models_lists( + 'models': merge_ollama_models_lists( map( - lambda response: response.get("models", []) if response else None, + lambda response: response.get('models', []) if response else None, responses, ) ) @@ -392,66 +373,53 @@ async def get_all_models(request: Request, user: UserModel = None): try: loaded_models = await get_ollama_loaded_models(request, user=user) - expires_map = { - m["model"]: m["expires_at"] - for m in loaded_models["models"] - if "expires_at" in m - } + expires_map = {m['model']: m['expires_at'] for m in loaded_models['models'] if 'expires_at' in m} - for m in models["models"]: - if m["model"] in expires_map: + for m in models['models']: + if m['model'] in expires_map: # Parse ISO8601 datetime with offset, get unix timestamp as int - dt = datetime.fromisoformat(expires_map[m["model"]]) - m["expires_at"] = int(dt.timestamp()) + dt = datetime.fromisoformat(expires_map[m['model']]) + m['expires_at'] = int(dt.timestamp()) except Exception as e: - log.debug(f"Failed to get loaded models: {e}") + log.debug(f'Failed to get loaded models: {e}') else: - models = {"models": []} + models = {'models': []} - request.app.state.OLLAMA_MODELS = { - model["model"]: model for model in models["models"] - } + request.app.state.OLLAMA_MODELS = {model['model']: model for model in models['models']} return models async def get_filtered_models(models, user, db=None): # Filter models based on user access control - model_ids = [model["model"] for model in models.get("models", [])] - model_infos = { - model_info.id: model_info - for model_info in Models.get_models_by_ids(model_ids, db=db) - } - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + model_ids = [model['model'] for model in models.get('models', [])] + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="model", + resource_type='model', resource_ids=list(model_infos.keys()), - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) filtered_models = [] - for model in models.get("models", []): - model_info = model_infos.get(model["model"]) + for model in models.get('models', []): + model_info = model_infos.get(model['model']) if model_info: if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) return filtered_models -@router.get("/api/tags") -@router.get("/api/tags/{url_idx}") -async def get_ollama_tags( - request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) -): +@router.get('/api/tags') +@router.get('/api/tags/{url_idx}') +async def get_ollama_tags(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') models = [] @@ -464,15 +432,15 @@ async def get_ollama_tags( r = None try: headers = { - **({"Authorization": f"Bearer {key}"} if key else {}), + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.request( - method="GET", - url=f"{url}/api/tags", + method='GET', + url=f'{url}/api/tags', headers=headers, ) r.raise_for_status() @@ -485,23 +453,23 @@ async def get_ollama_tags( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - models["models"] = await get_filtered_models(models, user) + if user.role == 'user' and not BYPASS_MODEL_ACCESS_CONTROL: + models['models'] = await get_filtered_models(models, user) return models -@router.get("/api/ps") +@router.get('/api/ps') async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)): """ List models that are currently loaded into Ollama memory, and which node they are loaded on. @@ -512,22 +480,18 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/ps", user=user)) + request_tasks.append(send_get_request(f'{url}/api/ps', user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - enable = api_config.get("enable", True) - key = api_config.get("key", None) + enable = api_config.get('enable', True) + key = api_config.get('key', None) if enable: - request_tasks.append( - send_get_request(f"{url}/api/ps", key, user=user) - ) + request_tasks.append(send_get_request(f'{url}/api/ps', key, user=user)) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -538,33 +502,31 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user url = request.app.state.config.OLLAMA_BASE_URLS[idx] api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) - for model in response.get("models", []): + for model in response.get('models', []): if prefix_id: - model["model"] = f"{prefix_id}.{model['model']}" + model['model'] = f'{prefix_id}.{model["model"]}' models = { - "models": merge_ollama_models_lists( + 'models': merge_ollama_models_lists( map( - lambda response: response.get("models", []) if response else None, + lambda response: response.get('models', []) if response else None, responses, ) ) } else: - models = {"models": []} + models = {'models': []} return models -@router.get("/api/version") -@router.get("/api/version/{url_idx}") +@router.get('/api/version') +@router.get('/api/version/{url_idx}') async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: @@ -574,18 +536,16 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - enable = api_config.get("enable", True) - key = api_config.get("key", None) + enable = api_config.get('enable', True) + key = api_config.get('key', None) if enable: request_tasks.append( send_get_request( - f"{url}/api/version", + f'{url}/api/version', key, ) ) @@ -596,12 +556,10 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if len(responses) > 0: lowest_version = min( responses, - key=lambda x: tuple( - map(int, re.sub(r"^v|-.*", "", x["version"]).split(".")) - ), + key=lambda x: tuple(map(int, re.sub(r'^v|-.*', '', x['version']).split('.'))), ) - return {"version": lowest_version["version"]} + return {'version': lowest_version['version']} else: raise HTTPException( status_code=500, @@ -612,7 +570,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): r = None try: - r = requests.request(method="GET", url=f"{url}/api/version") + r = requests.request(method='GET', url=f'{url}/api/version') r.raise_for_status() return r.json() @@ -623,49 +581,45 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) else: - return {"version": False} + return {'version': False} class ModelNameForm(BaseModel): model: Optional[str] = None model_config = ConfigDict( - extra="allow", + extra='allow', ) -@router.post("/api/unload") +@router.post('/api/unload') async def unload_model( request: Request, form_data: ModelNameForm, user=Depends(get_admin_user), ): form_data = form_data.model_dump(exclude_none=True) - model_name = form_data.get("model", form_data.get("name")) + model_name = form_data.get('model', form_data.get('name')) if not model_name: - raise HTTPException( - status_code=400, detail="Missing name of the model to unload." - ) + raise HTTPException(status_code=400, detail='Missing name of the model to unload.') # Refresh/load models if needed, get mapping from name to URLs await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if model_name not in models: - raise HTTPException( - status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name) - ) - url_indices = models[model_name]["urls"] + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)) + url_indices = models[model_name]['urls'] # Send unload to ALL url_indices results = [] @@ -677,36 +631,36 @@ async def unload_model( ) key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS) - prefix_id = api_config.get("prefix_id", None) - if prefix_id and model_name.startswith(f"{prefix_id}."): - model_name = model_name[len(f"{prefix_id}.") :] + prefix_id = api_config.get('prefix_id', None) + if prefix_id and model_name.startswith(f'{prefix_id}.'): + model_name = model_name[len(f'{prefix_id}.') :] - payload = {"model": model_name, "keep_alive": 0, "prompt": ""} + payload = {'model': model_name, 'keep_alive': 0, 'prompt': ''} try: res = await send_post_request( - url=f"{url}/api/generate", + url=f'{url}/api/generate', payload=json.dumps(payload), stream=False, key=key, user=user, ) - results.append({"url_idx": idx, "success": True, "response": res}) + results.append({'url_idx': idx, 'success': True, 'response': res}) except Exception as e: - log.exception(f"Failed to unload model on node {idx}: {e}") - errors.append({"url_idx": idx, "success": False, "error": str(e)}) + log.exception(f'Failed to unload model on node {idx}: {e}') + errors.append({'url_idx': idx, 'success': False, 'error': str(e)}) if len(errors) > 0: raise HTTPException( status_code=500, - detail=f"Failed to unload model on {len(errors)} nodes: {errors}", + detail=f'Failed to unload model on {len(errors)} nodes: {errors}', ) - return {"status": True} + return {'status': True} -@router.post("/api/pull") -@router.post("/api/pull/{url_idx}") +@router.post('/api/pull') +@router.post('/api/pull/{url_idx}') async def pull_model( request: Request, form_data: ModelNameForm, @@ -714,19 +668,19 @@ async def pull_model( user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') form_data = form_data.model_dump(exclude_none=True) - form_data["model"] = form_data.get("model", form_data.get("name")) + form_data['model'] = form_data.get('model', form_data.get('name')) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + log.info(f'url: {url}') # Admin should be able to pull models from any source - payload = {**form_data, "insecure": True} + payload = {**form_data, 'insecure': True} return await send_post_request( - url=f"{url}/api/pull", + url=f'{url}/api/pull', payload=json.dumps(payload), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, @@ -739,8 +693,8 @@ class PushModelForm(BaseModel): stream: Optional[bool] = None -@router.delete("/api/push") -@router.delete("/api/push/{url_idx}") +@router.delete('/api/push') +@router.delete('/api/push/{url_idx}') async def push_model( request: Request, form_data: PushModelForm, @@ -748,14 +702,14 @@ async def push_model( user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.model in models: - url_idx = models[form_data.model]["urls"][0] + url_idx = models[form_data.model]['urls'][0] else: raise HTTPException( status_code=400, @@ -763,10 +717,10 @@ async def push_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.debug(f"url: {url}") + log.debug(f'url: {url}') return await send_post_request( - url=f"{url}/api/push", + url=f'{url}/api/push', payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, @@ -778,11 +732,11 @@ class CreateModelForm(BaseModel): stream: Optional[bool] = None path: Optional[str] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') -@router.post("/api/create") -@router.post("/api/create/{url_idx}") +@router.post('/api/create') +@router.post('/api/create/{url_idx}') async def create_model( request: Request, form_data: CreateModelForm, @@ -790,13 +744,13 @@ async def create_model( user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') - log.debug(f"form_data: {form_data}") + log.debug(f'form_data: {form_data}') url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return await send_post_request( - url=f"{url}/api/create", + url=f'{url}/api/create', payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, @@ -808,8 +762,8 @@ class CopyModelForm(BaseModel): destination: str -@router.post("/api/copy") -@router.post("/api/copy/{url_idx}") +@router.post('/api/copy') +@router.post('/api/copy/{url_idx}') async def copy_model( request: Request, form_data: CopyModelForm, @@ -817,14 +771,14 @@ async def copy_model( user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.source in models: - url_idx = models[form_data.source]["urls"][0] + url_idx = models[form_data.source]['urls'][0] else: raise HTTPException( status_code=400, @@ -836,22 +790,22 @@ async def copy_model( try: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.request( - method="POST", - url=f"{url}/api/copy", + method='POST', + url=f'{url}/api/copy', headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) r.raise_for_status() - log.debug(f"r.text: {r.text}") + log.debug(f'r.text: {r.text}') return True except Exception as e: log.exception(e) @@ -860,19 +814,19 @@ async def copy_model( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) -@router.delete("/api/delete") -@router.delete("/api/delete/{url_idx}") +@router.delete('/api/delete') +@router.delete('/api/delete/{url_idx}') async def delete_model( request: Request, form_data: ModelNameForm, @@ -880,19 +834,19 @@ async def delete_model( user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') form_data = form_data.model_dump(exclude_none=True) - form_data["model"] = form_data.get("model", form_data.get("name")) + form_data['model'] = form_data.get('model', form_data.get('name')) - model = form_data.get("model") + model = form_data.get('model') if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if model in models: - url_idx = models[model]["urls"][0] + url_idx = models[model]['urls'][0] else: raise HTTPException( status_code=400, @@ -905,22 +859,22 @@ async def delete_model( r = None try: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.request( - method="DELETE", - url=f"{url}/api/delete", + method='DELETE', + url=f'{url}/api/delete', headers=headers, json=form_data, ) r.raise_for_status() - log.debug(f"r.text: {r.text}") + log.debug(f'r.text: {r.text}') return True except Exception as e: log.exception(e) @@ -929,31 +883,29 @@ async def delete_model( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) -@router.post("/api/show") -async def show_model_info( - request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) -): +@router.post('/api/show') +async def show_model_info(request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') form_data = form_data.model_dump(exclude_none=True) - form_data["model"] = form_data.get("model", form_data.get("name")) + form_data['model'] = form_data.get('model', form_data.get('name')) await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS - model = form_data.get("model") + model = form_data.get('model') if model not in models: raise HTTPException( @@ -961,23 +913,21 @@ async def show_model_info( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(models[model]["urls"]) + url_idx = random.choice(models[model]['urls']) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - r = requests.request( - method="POST", url=f"{url}/api/show", headers=headers, json=form_data - ) + r = requests.request(method='POST', url=f'{url}/api/show', headers=headers, json=form_data) r.raise_for_status() return r.json() @@ -988,14 +938,14 @@ async def show_model_info( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) @@ -1007,12 +957,12 @@ class GenerateEmbedForm(BaseModel): keep_alive: Optional[Union[int, str]] = None model_config = ConfigDict( - extra="allow", + extra='allow', ) -@router.post("/api/embed") -@router.post("/api/embed/{url_idx}") +@router.post('/api/embed') +@router.post('/api/embed/{url_idx}') async def embed( request: Request, form_data: GenerateEmbedForm, @@ -1020,9 +970,9 @@ async def embed( user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') - log.info(f"generate_ollama_batch_embeddings {form_data}") + log.info(f'generate_ollama_batch_embeddings {form_data}') if url_idx is None: model = form_data.model @@ -1034,7 +984,7 @@ async def embed( models = request.app.state.OLLAMA_MODELS if model in models: - url_idx = random.choice(models[model]["urls"]) + url_idx = random.choice(models[model]['urls']) else: raise HTTPException( status_code=400, @@ -1048,22 +998,22 @@ async def embed( ) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - form_data.model = form_data.model.replace(f"{prefix_id}.", "") + form_data.model = form_data.model.replace(f'{prefix_id}.', '') try: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.request( - method="POST", - url=f"{url}/api/embed", + method='POST', + url=f'{url}/api/embed', headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -1078,14 +1028,14 @@ async def embed( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) @@ -1096,8 +1046,8 @@ class GenerateEmbeddingsForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@router.post("/api/embeddings") -@router.post("/api/embeddings/{url_idx}") +@router.post('/api/embeddings') +@router.post('/api/embeddings/{url_idx}') async def embeddings( request: Request, form_data: GenerateEmbeddingsForm, @@ -1105,9 +1055,9 @@ async def embeddings( user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') - log.info(f"generate_ollama_embeddings {form_data}") + log.info(f'generate_ollama_embeddings {form_data}') if url_idx is None: model = form_data.model @@ -1119,7 +1069,7 @@ async def embeddings( models = request.app.state.OLLAMA_MODELS if model in models: - url_idx = random.choice(models[model]["urls"]) + url_idx = random.choice(models[model]['urls']) else: raise HTTPException( status_code=400, @@ -1133,22 +1083,22 @@ async def embeddings( ) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - form_data.model = form_data.model.replace(f"{prefix_id}.", "") + form_data.model = form_data.model.replace(f'{prefix_id}.', '') try: headers = { - "Content-Type": "application/json", - **({"Authorization": f"Bearer {key}"} if key else {}), + 'Content-Type': 'application/json', + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) r = requests.request( - method="POST", - url=f"{url}/api/embeddings", + method='POST', + url=f'{url}/api/embeddings', headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -1163,14 +1113,14 @@ async def embeddings( if r is not None: try: res = r.json() - if "error" in res: - detail = f"Ollama: {res['error']}" + if 'error' in res: + detail = f'Ollama: {res["error"]}' except Exception: - detail = f"Ollama: {e}" + detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) @@ -1189,8 +1139,8 @@ class GenerateCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@router.post("/api/generate") -@router.post("/api/generate/{url_idx}") +@router.post('/api/generate') +@router.post('/api/generate/{url_idx}') async def generate_completion( request: Request, form_data: GenerateCompletionForm, @@ -1198,7 +1148,7 @@ async def generate_completion( user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') if url_idx is None: await get_all_models(request, user=user) @@ -1206,7 +1156,7 @@ async def generate_completion( model = form_data.model if model in models: - url_idx = random.choice(models[model]["urls"]) + url_idx = random.choice(models[model]['urls']) else: raise HTTPException( status_code=400, @@ -1219,12 +1169,12 @@ async def generate_completion( request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - form_data.model = form_data.model.replace(f"{prefix_id}.", "") + form_data.model = form_data.model.replace(f'{prefix_id}.', '') return await send_post_request( - url=f"{url}/api/generate", + url=f'{url}/api/generate', payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, @@ -1237,16 +1187,12 @@ class ChatMessage(BaseModel): tool_calls: Optional[list[dict]] = None images: Optional[list[str]] = None - @validator("content", pre=True) + @validator('content', pre=True) @classmethod def check_at_least_one_field(cls, field_value, values, **kwargs): # Raise an error if both 'content' and 'tool_calls' are None - if field_value is None and ( - "tool_calls" not in values or values["tool_calls"] is None - ): - raise ValueError( - "At least one of 'content' or 'tool_calls' must be provided" - ) + if field_value is None and ('tool_calls' not in values or values['tool_calls'] is None): + raise ValueError("At least one of 'content' or 'tool_calls' must be provided") return field_value @@ -1261,7 +1207,7 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None tools: Optional[list[dict]] = None model_config = ConfigDict( - extra="allow", + extra='allow', ) @@ -1273,13 +1219,13 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(models[model].get("urls", [])) + url_idx = random.choice(models[model].get('urls', [])) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return url, url_idx -@router.post("/api/chat") -@router.post("/api/chat/{url_idx}") +@router.post('/api/chat') +@router.post('/api/chat/{url_idx}') async def generate_chat_completion( request: Request, form_data: dict, @@ -1288,7 +1234,7 @@ async def generate_chat_completion( bypass_system_prompt: bool = False, ): if not request.app.state.config.ENABLE_OLLAMA_API: - raise HTTPException(status_code=503, detail="Ollama API is disabled") + raise HTTPException(status_code=503, detail='Ollama API is disabled') # NOTE: We intentionally do NOT use Depends(get_session) here. # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. @@ -1298,11 +1244,11 @@ async def generate_chat_completion( # bypass_filter is read from request.state to prevent external clients from # setting it via query parameter (CVE fix). Only internal server-side callers # (e.g. utils/chat.py) should set request.state.bypass_filter = True. - bypass_filter = getattr(request.state, "bypass_filter", False) + bypass_filter = getattr(request.state, 'bypass_filter', False) if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - metadata = form_data.pop("metadata", None) + metadata = form_data.pop('metadata', None) try: form_data = GenerateChatCompletionForm(**form_data) except Exception as e: @@ -1315,72 +1261,68 @@ async def generate_chat_completion( if isinstance(form_data, BaseModel): payload = {**form_data.model_dump(exclude_none=True)} - if "metadata" in payload: - del payload["metadata"] + if 'metadata' in payload: + del payload['metadata'] - model_id = payload["model"] + model_id = payload['model'] model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: base_model_id = ( - request.base_model_id - if hasattr(request, "base_model_id") - else model_info.base_model_id + request.base_model_id if hasattr(request, 'base_model_id') else model_info.base_model_id ) # Use request's base_model_id if available - payload["model"] = base_model_id + payload['model'] = base_model_id params = model_info.params.model_dump() if params: - system = params.pop("system", None) + system = params.pop('system', None) payload = apply_model_params_to_body_ollama(params, payload) if not bypass_system_prompt: payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model - if not bypass_filter and user.role == "user": - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id) - } + if not bypass_filter and user.role == 'user': + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not ( user.id == model_info.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model_info.id, - permission="read", + permission='read', user_group_ids=user_group_ids, ) ): raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) elif not bypass_filter: - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) - url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + url, url_idx = await get_ollama_url(request, payload['model'], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(url_idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + payload['model'] = payload['model'].replace(f'{prefix_id}.', '') return await send_post_request( - url=f"{url}/api/chat", + url=f'{url}/api/chat', payload=json.dumps(payload), stream=form_data.stream, key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), - content_type="application/x-ndjson", + content_type='application/x-ndjson', user=user, metadata=metadata, ) @@ -1389,32 +1331,32 @@ async def generate_chat_completion( # TODO: we should update this part once Ollama supports other types class OpenAIChatMessageContent(BaseModel): type: str - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class OpenAIChatMessage(BaseModel): role: str content: Union[Optional[str], list[OpenAIChatMessageContent]] - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class OpenAIChatCompletionForm(BaseModel): model: str messages: list[OpenAIChatMessage] - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class OpenAICompletionForm(BaseModel): model: str prompt: str - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') -@router.post("/v1/completions") -@router.post("/v1/completions/{url_idx}") +@router.post('/v1/completions') +@router.post('/v1/completions/{url_idx}') async def generate_openai_completion( request: Request, form_data: dict, @@ -1425,7 +1367,7 @@ async def generate_openai_completion( # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. - metadata = form_data.pop("metadata", None) + metadata = form_data.pop('metadata', None) try: form_data = OpenAICompletionForm(**form_data) @@ -1436,69 +1378,67 @@ async def generate_openai_completion( detail=str(e), ) - payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} - if "metadata" in payload: - del payload["metadata"] + payload = {**form_data.model_dump(exclude_none=True, exclude=['metadata'])} + if 'metadata' in payload: + del payload['metadata'] model_id = form_data.model model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: - payload["model"] = model_info.base_model_id + payload['model'] = model_info.base_model_id params = model_info.params.model_dump() if params: payload = apply_model_params_to_body_openai(params, payload) # Check if user has access to the model - if user.role == "user": - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id) - } + if user.role == 'user': + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not ( user.id == model_info.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model_info.id, - permission="read", + permission='read', user_group_ids=user_group_ids, ) ): raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) else: - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) - url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + url, url_idx = await get_ollama_url(request, payload['model'], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(url_idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + payload['model'] = payload['model'].replace(f'{prefix_id}.', '') return await send_post_request( - url=f"{url}/v1/completions", + url=f'{url}/v1/completions', payload=json.dumps(payload), - stream=payload.get("stream", False), + stream=payload.get('stream', False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, metadata=metadata, ) -@router.post("/v1/chat/completions") -@router.post("/v1/chat/completions/{url_idx}") +@router.post('/v1/chat/completions') +@router.post('/v1/chat/completions/{url_idx}') async def generate_openai_chat_completion( request: Request, form_data: dict, @@ -1509,7 +1449,7 @@ async def generate_openai_chat_completion( # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. - metadata = form_data.pop("metadata", None) + metadata = form_data.pop('metadata', None) try: completion_form = OpenAIChatCompletionForm(**form_data) @@ -1520,160 +1460,149 @@ async def generate_openai_chat_completion( detail=str(e), ) - payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} - if "metadata" in payload: - del payload["metadata"] + payload = {**completion_form.model_dump(exclude_none=True, exclude=['metadata'])} + if 'metadata' in payload: + del payload['metadata'] model_id = completion_form.model model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: - payload["model"] = model_info.base_model_id + payload['model'] = model_info.base_model_id params = model_info.params.model_dump() if params: - system = params.pop("system", None) + system = params.pop('system', None) payload = apply_model_params_to_body_openai(params, payload) payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model - if user.role == "user": - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id) - } + if user.role == 'user': + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not ( user.id == model_info.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model_info.id, - permission="read", + permission='read', user_group_ids=user_group_ids, ) ): raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) else: - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) - url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + url, url_idx = await get_ollama_url(request, payload['model'], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(url_idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + payload['model'] = payload['model'].replace(f'{prefix_id}.', '') return await send_post_request( - url=f"{url}/v1/chat/completions", + url=f'{url}/v1/chat/completions', payload=json.dumps(payload), - stream=payload.get("stream", False), + stream=payload.get('stream', False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, metadata=metadata, ) -@router.get("/v1/models") -@router.get("/v1/models/{url_idx}") +@router.get('/v1/models') +@router.get('/v1/models/{url_idx}') async def get_openai_models( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - models = [] if url_idx is None: model_list = await get_all_models(request, user=user) models = [ { - "id": model["model"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", + 'id': model['model'], + 'object': 'model', + 'created': int(time.time()), + 'owned_by': 'openai', } - for model in model_list["models"] + for model in model_list['models'] ] else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] try: - r = requests.request(method="GET", url=f"{url}/api/tags") + r = requests.request(method='GET', url=f'{url}/api/tags') r.raise_for_status() model_list = r.json() models = [ { - "id": model["model"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", + 'id': model['model'], + 'object': 'model', + 'created': int(time.time()), + 'owned_by': 'openai', } - for model in models["models"] + for model in models['models'] ] except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + error_detail = 'Open WebUI: Server Connection Error' if r is not None: try: res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" + if 'error' in res: + error_detail = f'Ollama: {res["error"]}' except Exception: - error_detail = f"Ollama: {e}" + error_detail = f'Ollama: {e}' raise HTTPException( status_code=r.status_code if r else 500, detail=error_detail, ) - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + if user.role == 'user' and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control - model_ids = [model["id"] for model in models] - model_infos = { - model_info.id: model_info - for model_info in Models.get_models_by_ids(model_ids, db=db) - } - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + model_ids = [model['id'] for model in models] + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="model", + resource_type='model', resource_ids=list(model_infos.keys()), - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) filtered_models = [] for model in models: - model_info = model_infos.get(model["id"]) + model_info = model_infos.get(model['id']) if model_info: - if ( - user.id == model_info.user_id - or model_info.id in accessible_model_ids - ): + if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) models = filtered_models return { - "data": models, - "object": "list", + 'data': models, + 'object': 'list', } @@ -1691,7 +1620,7 @@ def parse_huggingface_url(hf_url): parsed_url = urlparse(hf_url) # Get the path and split it into components - path_components = parsed_url.path.split("/") + path_components = parsed_url.path.split('/') # Extract the desired output model_file = path_components[-1] @@ -1701,9 +1630,7 @@ def parse_huggingface_url(hf_url): return None -async def download_file_stream( - ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024 -): +async def download_file_stream(ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024): done = False if os.path.exists(file_path): @@ -1711,17 +1638,15 @@ async def download_file_stream( else: current_size = 0 - headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} + headers = {'Range': f'bytes={current_size}-'} if current_size > 0 else {} timeout = aiohttp.ClientTimeout(total=600) # Set the timeout async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get( - file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL - ) as response: - total_size = int(response.headers.get("content-length", 0)) + current_size + async with session.get(file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response: + total_size = int(response.headers.get('content-length', 0)) + current_size - with open(file_path, "ab+") as file: + with open(file_path, 'ab+') as file: async for data in response.content.iter_chunked(chunk_size): current_size += len(data) file.write(data) @@ -1735,42 +1660,40 @@ async def download_file_stream( file.close() hashed = calculate_sha256(file_path, chunk_size) - with open(file_path, "rb") as file: + with open(file_path, 'rb') as file: chunk_size = 1024 * 1024 * 2 - url = f"{ollama_url}/api/blobs/sha256:{hashed}" + url = f'{ollama_url}/api/blobs/sha256:{hashed}' with requests.Session() as session: response = session.post(url, data=file, timeout=30) if response.ok: res = { - "done": done, - "blob": f"sha256:{hashed}", - "name": file_name, + 'done': done, + 'blob': f'sha256:{hashed}', + 'name': file_name, } os.remove(file_path) - yield f"data: {json.dumps(res)}\n\n" + yield f'data: {json.dumps(res)}\n\n' else: - raise RuntimeError( - "Ollama: Could not create blob, Please try again." - ) + raise RuntimeError('Ollama: Could not create blob, Please try again.') # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" -@router.post("/models/download") -@router.post("/models/download/{url_idx}") +@router.post('/models/download') +@router.post('/models/download/{url_idx}') async def download_model( request: Request, form_data: UrlForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - allowed_hosts = ["https://huggingface.co/", "https://github.com/"] + allowed_hosts = ['https://huggingface.co/', 'https://github.com/'] if not any(form_data.url.startswith(host) for host in allowed_hosts): raise HTTPException( status_code=400, - detail="Invalid file_url. Only URLs from allowed hosts are permitted.", + detail='Invalid file_url. Only URLs from allowed hosts are permitted.', ) if url_idx is None: @@ -1780,7 +1703,7 @@ async def download_model( file_name = parse_huggingface_url(form_data.url) if file_name: - file_path = f"{UPLOAD_DIR}/{file_name}" + file_path = f'{UPLOAD_DIR}/{file_name}' return StreamingResponse( download_file_stream(url, form_data.url, file_path, file_name), @@ -1790,8 +1713,8 @@ async def download_model( # TODO: Progress bar does not reflect size & duration of upload. -@router.post("/models/upload") -@router.post("/models/upload/{url_idx}") +@router.post('/models/upload') +@router.post('/models/upload/{url_idx}') async def upload_model( request: Request, file: UploadFile = File(...), @@ -1808,7 +1731,7 @@ async def upload_model( # --- P1: save file locally --- chunk_size = 1024 * 1024 * 2 # 2 MB chunks - with open(file_path, "wb") as out_f: + with open(file_path, 'wb') as out_f: while True: chunk = file.file.read(chunk_size) # log.info(f"Chunk: {str(chunk)}") # DEBUG @@ -1819,72 +1742,70 @@ async def upload_model( async def file_process_stream(): nonlocal ollama_url total_size = os.path.getsize(file_path) - log.info(f"Total Model Size: {str(total_size)}") # DEBUG + log.info(f'Total Model Size: {str(total_size)}') # DEBUG # --- P2: SSE progress + calculate sha256 hash --- file_hash = calculate_sha256(file_path, chunk_size) - log.info(f"Model Hash: {str(file_hash)}") # DEBUG + log.info(f'Model Hash: {str(file_hash)}') # DEBUG try: - with open(file_path, "rb") as f: + with open(file_path, 'rb') as f: bytes_read = 0 while chunk := f.read(chunk_size): bytes_read += len(chunk) progress = round(bytes_read / total_size * 100, 2) data_msg = { - "progress": progress, - "total": total_size, - "completed": bytes_read, + 'progress': progress, + 'total': total_size, + 'completed': bytes_read, } - yield f"data: {json.dumps(data_msg)}\n\n" + yield f'data: {json.dumps(data_msg)}\n\n' # --- P3: Upload to ollama /api/blobs --- - with open(file_path, "rb") as f: - url = f"{ollama_url}/api/blobs/sha256:{file_hash}" + with open(file_path, 'rb') as f: + url = f'{ollama_url}/api/blobs/sha256:{file_hash}' response = requests.post(url, data=f) if response.ok: - log.info(f"Uploaded to /api/blobs") # DEBUG + log.info(f'Uploaded to /api/blobs') # DEBUG # Remove local file os.remove(file_path) # Create model in ollama model_name, ext = os.path.splitext(filename) - log.info(f"Created Model: {model_name}") # DEBUG + log.info(f'Created Model: {model_name}') # DEBUG create_payload = { - "model": model_name, + 'model': model_name, # Reference the file by its original name => the uploaded blob's digest - "files": {filename: f"sha256:{file_hash}"}, + 'files': {filename: f'sha256:{file_hash}'}, } - log.info(f"Model Payload: {create_payload}") # DEBUG + log.info(f'Model Payload: {create_payload}') # DEBUG # Call ollama /api/create # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model create_resp = requests.post( - url=f"{ollama_url}/api/create", - headers={"Content-Type": "application/json"}, + url=f'{ollama_url}/api/create', + headers={'Content-Type': 'application/json'}, data=json.dumps(create_payload), ) if create_resp.ok: - log.info(f"API SUCCESS!") # DEBUG + log.info(f'API SUCCESS!') # DEBUG done_msg = { - "done": True, - "blob": f"sha256:{file_hash}", - "name": filename, - "model_created": model_name, + 'done': True, + 'blob': f'sha256:{file_hash}', + 'name': filename, + 'model_created': model_name, } - yield f"data: {json.dumps(done_msg)}\n\n" + yield f'data: {json.dumps(done_msg)}\n\n' else: - raise Exception( - f"Failed to create model in Ollama. {create_resp.text}" - ) + raise Exception(f'Failed to create model in Ollama. {create_resp.text}') else: - raise Exception("Ollama: Could not create blob, Please try again.") + raise Exception('Ollama: Could not create blob, Please try again.') except Exception as e: - res = {"error": str(e)} - yield f"data: {json.dumps(res)}\n\n" + res = {'error': str(e)} + yield f'data: {json.dumps(res)}\n\n' - return StreamingResponse(file_process_stream(), media_type="text/event-stream") + return StreamingResponse(file_process_stream(), media_type='text/event-stream') diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 6ad49a112b..89fbf4852c 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -74,7 +74,7 @@ async def send_get_request(url, key=None, user: UserModel = None): try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: headers = { - **({"Authorization": f"Bearer {key}"} if key else {}), + **({'Authorization': f'Bearer {key}'} if key else {}), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: @@ -88,33 +88,33 @@ async def send_get_request(url, key=None, user: UserModel = None): return await response.json() except Exception as e: # Handle connection error here - log.error(f"Connection error: {e}") + log.error(f'Connection error: {e}') return None async def get_models_request(url, key=None, user: UserModel = None): if is_anthropic_url(url): return await get_anthropic_models(url, key, user=user) - return await send_get_request(f"{url}/models", key, user=user) + return await send_get_request(f'{url}/models', key, user=user) def openai_reasoning_model_handler(payload): """ Handle reasoning model specific parameters """ - if "max_tokens" in payload: + if 'max_tokens' in payload: # Convert "max_tokens" to "max_completion_tokens" for all reasoning models - payload["max_completion_tokens"] = payload["max_tokens"] - del payload["max_tokens"] + payload['max_completion_tokens'] = payload['max_tokens'] + del payload['max_tokens'] # Handle system role conversion based on model type - if payload["messages"][0]["role"] == "system": - model_lower = payload["model"].lower() + if payload['messages'][0]['role'] == 'system': + model_lower = payload['model'].lower() # Legacy models use "user" role instead of "system" - if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"): - payload["messages"][0]["role"] = "user" + if model_lower.startswith('o1-mini') or model_lower.startswith('o1-preview'): + payload['messages'][0]['role'] = 'user' else: - payload["messages"][0]["role"] = "developer" + payload['messages'][0]['role'] = 'developer' return payload @@ -129,57 +129,57 @@ async def get_headers_and_cookies( ): cookies = {} headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json', **( { - "HTTP-Referer": "https://openwebui.com/", - "X-Title": "Open WebUI", + 'HTTP-Referer': 'https://openwebui.com/', + 'X-Title': 'Open WebUI', } - if "openrouter.ai" in url + if 'openrouter.ai' in url else {} ), } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - if metadata and metadata.get("chat_id"): - headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get("chat_id") + if metadata and metadata.get('chat_id'): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id') token = None - auth_type = config.get("auth_type") + auth_type = config.get('auth_type') - if auth_type == "bearer" or auth_type is None: + if auth_type == 'bearer' or auth_type is None: # Default to bearer if not specified - token = f"{key}" - elif auth_type == "none": + token = f'{key}' + elif auth_type == 'none': token = None - elif auth_type == "session": + elif auth_type == 'session': cookies = request.cookies token = request.state.token.credentials - elif auth_type == "system_oauth": + elif auth_type == 'system_oauth': cookies = request.cookies oauth_token = None try: - if request.cookies.get("oauth_session_id", None): + if request.cookies.get('oauth_session_id', None): oauth_token = await request.app.state.oauth_manager.get_oauth_token( user.id, - request.cookies.get("oauth_session_id", None), + request.cookies.get('oauth_session_id', None), ) except Exception as e: - log.error(f"Error getting OAuth token: {e}") + log.error(f'Error getting OAuth token: {e}') if oauth_token: - token = f"{oauth_token.get('access_token', '')}" + token = f'{oauth_token.get("access_token", "")}' - elif auth_type in ("azure_ad", "microsoft_entra_id"): + elif auth_type in ('azure_ad', 'microsoft_entra_id'): token = get_microsoft_entra_id_access_token() if token: - headers["Authorization"] = f"Bearer {token}" + headers['Authorization'] = f'Bearer {token}' - if config.get("headers") and isinstance(config.get("headers"), dict): - headers = {**headers, **config.get("headers")} + if config.get('headers') and isinstance(config.get('headers'), dict): + headers = {**headers, **config.get('headers')} return headers, cookies @@ -191,11 +191,11 @@ def get_microsoft_entra_id_access_token(): """ try: token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + DefaultAzureCredential(), 'https://cognitiveservices.azure.com/.default' ) return token_provider() except Exception as e: - log.error(f"Error getting Microsoft Entra ID access token: {e}") + log.error(f'Error getting Microsoft Entra ID access token: {e}') return None @@ -208,13 +208,13 @@ def get_microsoft_entra_id_access_token(): router = APIRouter() -@router.get("/config") +@router.get('/config') async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, + 'ENABLE_OPENAI_API': request.app.state.config.ENABLE_OPENAI_API, + 'OPENAI_API_BASE_URLS': request.app.state.config.OPENAI_API_BASE_URLS, + 'OPENAI_API_KEYS': request.app.state.config.OPENAI_API_KEYS, + 'OPENAI_API_CONFIGS': request.app.state.config.OPENAI_API_CONFIGS, } @@ -225,30 +225,21 @@ class OpenAIConfigForm(BaseModel): OPENAI_API_CONFIGS: dict -@router.post("/config/update") -async def update_config( - request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_config(request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)): request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS # Check if API KEYS length is same than API URLS length - if len(request.app.state.config.OPENAI_API_KEYS) != len( - request.app.state.config.OPENAI_API_BASE_URLS - ): - if len(request.app.state.config.OPENAI_API_KEYS) > len( - request.app.state.config.OPENAI_API_BASE_URLS - ): - request.app.state.config.OPENAI_API_KEYS = ( - request.app.state.config.OPENAI_API_KEYS[ - : len(request.app.state.config.OPENAI_API_BASE_URLS) - ] - ) + if len(request.app.state.config.OPENAI_API_KEYS) != len(request.app.state.config.OPENAI_API_BASE_URLS): + if len(request.app.state.config.OPENAI_API_KEYS) > len(request.app.state.config.OPENAI_API_BASE_URLS): + request.app.state.config.OPENAI_API_KEYS = request.app.state.config.OPENAI_API_KEYS[ + : len(request.app.state.config.OPENAI_API_BASE_URLS) + ] else: - request.app.state.config.OPENAI_API_KEYS += [""] * ( - len(request.app.state.config.OPENAI_API_BASE_URLS) - - len(request.app.state.config.OPENAI_API_KEYS) + request.app.state.config.OPENAI_API_KEYS += [''] * ( + len(request.app.state.config.OPENAI_API_BASE_URLS) - len(request.app.state.config.OPENAI_API_KEYS) ) request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS @@ -256,34 +247,30 @@ async def update_config( # Remove the API configs that are not in the API URLS keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS)))) request.app.state.config.OPENAI_API_CONFIGS = { - key: value - for key, value in request.app.state.config.OPENAI_API_CONFIGS.items() - if key in keys + key: value for key, value in request.app.state.config.OPENAI_API_CONFIGS.items() if key in keys } return { - "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, + 'ENABLE_OPENAI_API': request.app.state.config.ENABLE_OPENAI_API, + 'OPENAI_API_BASE_URLS': request.app.state.config.OPENAI_API_BASE_URLS, + 'OPENAI_API_KEYS': request.app.state.config.OPENAI_API_KEYS, + 'OPENAI_API_CONFIGS': request.app.state.config.OPENAI_API_CONFIGS, } -@router.post("/audio/speech") +@router.post('/audio/speech') async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = request.app.state.config.OPENAI_API_BASE_URLS.index( - "https://api.openai.com/v1" - ) + idx = request.app.state.config.OPENAI_API_BASE_URLS.index('https://api.openai.com/v1') body = await request.body() name = hashlib.sha256(body).hexdigest() - SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech" + SPEECH_CACHE_DIR = CACHE_DIR / 'audio' / 'speech' SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + file_path = SPEECH_CACHE_DIR.joinpath(f'{name}.mp3') + file_body_path = SPEECH_CACHE_DIR.joinpath(f'{name}.json') # Check if the file already exists in the cache if file_path.is_file(): @@ -296,14 +283,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support ) - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) r = None try: r = requests.post( - url=f"{url}/audio/speech", + url=f'{url}/audio/speech', data=body, headers=headers, cookies=cookies, @@ -313,12 +298,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): r.raise_for_status() # Save the streaming content to a file - with open(file_path, "wb") as f: + with open(file_path, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) + with open(file_body_path, 'w') as f: + json.dump(json.loads(body.decode('utf-8')), f) # Return the saved file return FileResponse(file_path) @@ -330,14 +315,14 @@ async def speech(request: Request, user=Depends(get_verified_user)): if r is not None: try: res = r.json() - if "error" in res: - detail = f"External: {res['error']}" + if 'error' in res: + detail = f'External: {res["error"]}' except Exception: - detail = f"External: {e}" + detail = f'External: {e}' raise HTTPException( status_code=r.status_code if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail=detail if detail else 'Open WebUI: Server Connection Error', ) except ValueError: @@ -366,7 +351,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: request.app.state.config.OPENAI_API_KEYS = api_keys # if there are more urls than keys, add empty keys else: - api_keys += [""] * (num_urls - num_keys) + api_keys += [''] * (num_urls - num_keys) request.app.state.config.OPENAI_API_KEYS = api_keys request_tasks = [] @@ -379,32 +364,28 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: api_configs.get(url, {}), # Legacy support ) - enable = api_config.get("enable", True) - model_ids = api_config.get("model_ids", []) + enable = api_config.get('enable', True) + model_ids = api_config.get('model_ids', []) if enable: if len(model_ids) == 0: - request_tasks.append( - get_models_request(url, api_keys[idx], user=user) - ) + request_tasks.append(get_models_request(url, api_keys[idx], user=user)) else: model_list = { - "object": "list", - "data": [ + 'object': 'list', + 'data': [ { - "id": model_id, - "name": model_id, - "owned_by": "openai", - "openai": {"id": model_id}, - "urlIdx": idx, + 'id': model_id, + 'name': model_id, + 'owned_by': 'openai', + 'openai': {'id': model_id}, + 'urlIdx': idx, } for model_id in model_ids ], } - request_tasks.append( - asyncio.ensure_future(asyncio.sleep(0, model_list)) - ) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -418,61 +399,52 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: api_configs.get(url, {}), # Legacy support ) - connection_type = api_config.get("connection_type", "external") - prefix_id = api_config.get("prefix_id", None) - tags = api_config.get("tags", []) + connection_type = api_config.get('connection_type', 'external') + prefix_id = api_config.get('prefix_id', None) + tags = api_config.get('tags', []) - model_list = ( - response if isinstance(response, list) else response.get("data", []) - ) + model_list = response if isinstance(response, list) else response.get('data', []) if not isinstance(model_list, list): # Catch non-list responses model_list = [] for model in model_list: # Remove name key if its value is None #16689 - if "name" in model and model["name"] is None: - del model["name"] + if 'name' in model and model['name'] is None: + del model['name'] if prefix_id: - model["id"] = ( - f"{prefix_id}.{model.get('id', model.get('name', ''))}" - ) + model['id'] = f'{prefix_id}.{model.get("id", model.get("name", ""))}' if tags: - model["tags"] = tags + model['tags'] = tags if connection_type: - model["connection_type"] = connection_type + model['connection_type'] = connection_type - log.debug(f"get_all_models:responses() {responses}") + log.debug(f'get_all_models:responses() {responses}') return responses async def get_filtered_models(models, user, db=None): # Filter models based on user access control - model_ids = [model["id"] for model in models.get("data", [])] - model_infos = { - model_info.id: model_info - for model_info in Models.get_models_by_ids(model_ids, db=db) - } - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + model_ids = [model['id'] for model in models.get('data', [])] + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="model", + resource_type='model', resource_ids=list(model_infos.keys()), - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) filtered_models = [] - for model in models.get("data", []): - model_info = model_infos.get(model["id"]) + for model in models.get('data', []): + model_info = model_infos.get(model['id']) if model_info: if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) @@ -481,13 +453,13 @@ async def get_filtered_models(models, user, db=None): @cached( ttl=MODELS_CACHE_TTL, - key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models", + key=lambda _, user: f'openai_all_models_{user.id}' if user else 'openai_all_models', ) async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: - log.info("get_all_models()") + log.info('get_all_models()') if not request.app.state.config.ENABLE_OPENAI_API: - return {"data": []} + return {'data': []} # Cache config value locally to avoid repeated Redis lookups inside # the nested loop in get_merged_models (one GET per model otherwise). @@ -496,8 +468,8 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: responses = await get_all_models_responses(request, user=user) def extract_data(response): - if response and "data" in response: - return response["data"] + if response and 'data' in response: + return response['data'] if isinstance(response, list): return response return None @@ -506,63 +478,59 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: if any( name in model_id for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", + 'babbage', + 'dall-e', + 'davinci', + 'embedding', + 'tts', + 'whisper', ] ): return False return True def get_merged_models(model_lists): - log.debug(f"merge_models_lists {model_lists}") + log.debug(f'merge_models_lists {model_lists}') models = {} for idx, model_list in enumerate(model_lists): - if model_list is not None and "error" not in model_list: + if model_list is not None and 'error' not in model_list: for model in model_list: - model_id = model.get("id") or model.get("name") + model_id = model.get('id') or model.get('name') base_url = api_base_urls[idx] hostname = urlparse(base_url).hostname if base_url else None - if hostname == "api.openai.com" and not is_supported_openai_models( - model_id - ): + if hostname == 'api.openai.com' and not is_supported_openai_models(model_id): # Skip unwanted OpenAI models continue if model_id and model_id not in models: models[model_id] = { **model, - "name": model.get("name", model_id), - "owned_by": "openai", - "openai": model, - "connection_type": model.get("connection_type", "external"), - "urlIdx": idx, + 'name': model.get('name', model_id), + 'owned_by': 'openai', + 'openai': model, + 'connection_type': model.get('connection_type', 'external'), + 'urlIdx': idx, } return models models = get_merged_models(map(extract_data, responses)) - log.debug(f"models: {models}") + log.debug(f'models: {models}') request.app.state.OPENAI_MODELS = models - return {"data": list(models.values())} + return {'data': list(models.values())} -@router.get("/models") -@router.get("/models/{url_idx}") -async def get_models( - request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) -): +@router.get('/models') +@router.get('/models/{url_idx}') +async def get_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_OPENAI_API: - raise HTTPException(status_code=503, detail="OpenAI API is disabled") + raise HTTPException(status_code=503, detail='OpenAI API is disabled') models = { - "data": [], + 'data': [], } if url_idx is None: @@ -582,51 +550,49 @@ async def get_models( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) - if api_config.get("azure", False): + if api_config.get('azure', False): models = { - "data": api_config.get("model_ids", []) or [], - "object": "list", + 'data': api_config.get('model_ids', []) or [], + 'object': 'list', } elif is_anthropic_url(url): models = await get_anthropic_models(url, key, user=user) if models is None: - raise Exception("Failed to connect to Anthropic API") + raise Exception('Failed to connect to Anthropic API') else: async with session.get( - f"{url}/models", + f'{url}/models', headers=headers, cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: - error_detail = f"HTTP Error: {r.status}" + error_detail = f'HTTP Error: {r.status}' try: res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" + if 'error' in res: + error_detail = f'External Error: {res["error"]}' except Exception: pass raise Exception(error_detail) response_data = await r.json() - if "api.openai.com" in url: - response_data["data"] = [ + if 'api.openai.com' in url: + response_data['data'] = [ model - for model in response_data.get("data", []) + for model in response_data.get('data', []) if not any( - name in model["id"] + name in model['id'] for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", + 'babbage', + 'dall-e', + 'davinci', + 'embedding', + 'tts', + 'whisper', ] ) ] @@ -634,17 +600,15 @@ async def get_models( models = response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues - log.exception(f"Client error: {str(e)}") - raise HTTPException( - status_code=500, detail="Open WebUI: Server Connection Error" - ) + log.exception(f'Client error: {str(e)}') + raise HTTPException(status_code=500, detail='Open WebUI: Server Connection Error') except Exception as e: - log.exception(f"Unexpected error: {e}") - error_detail = f"Unexpected error: {str(e)}" + log.exception(f'Unexpected error: {e}') + error_detail = f'Unexpected error: {str(e)}' raise HTTPException(status_code=500, detail=error_detail) - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - models["data"] = await get_filtered_models(models, user) + if user.role == 'user' and not BYPASS_MODEL_ACCESS_CONTROL: + models['data'] = await get_filtered_models(models, user) return models @@ -656,7 +620,7 @@ class ConnectionVerificationForm(BaseModel): config: Optional[dict] = None -@router.post("/verify") +@router.post('/verify') async def verify_connection( request: Request, form_data: ConnectionVerificationForm, @@ -672,19 +636,17 @@ async def verify_connection( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) - if api_config.get("azure", False): + if api_config.get('azure', False): # Only set api-key header if not using Azure Entra ID authentication - auth_type = api_config.get("auth_type", "bearer") - if auth_type not in ("azure_ad", "microsoft_entra_id"): - headers["api-key"] = key + auth_type = api_config.get('auth_type', 'bearer') + if auth_type not in ('azure_ad', 'microsoft_entra_id'): + headers['api-key'] = key - api_version = api_config.get("api_version", "") or "2023-03-15-preview" + api_version = api_config.get('api_version', '') or '2023-03-15-preview' async with session.get( - url=f"{url}/openai/models?api-version={api_version}", + url=f'{url}/openai/models?api-version={api_version}', headers=headers, cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -696,27 +658,21 @@ async def verify_connection( if r.status != 200: if isinstance(response_data, (dict, list)): - return JSONResponse( - status_code=r.status, content=response_data - ) + return JSONResponse(status_code=r.status, content=response_data) else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) + return PlainTextResponse(status_code=r.status, content=response_data) return response_data elif is_anthropic_url(url): result = await get_anthropic_models(url, key) if result is None: - raise HTTPException( - status_code=500, detail="Failed to connect to Anthropic API" - ) - if "error" in result: - raise HTTPException(status_code=500, detail=result["error"]) + raise HTTPException(status_code=500, detail='Failed to connect to Anthropic API') + if 'error' in result: + raise HTTPException(status_code=500, detail=result['error']) return result else: async with session.get( - f"{url}/models", + f'{url}/models', headers=headers, cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -728,77 +684,67 @@ async def verify_connection( if r.status != 200: if isinstance(response_data, (dict, list)): - return JSONResponse( - status_code=r.status, content=response_data - ) + return JSONResponse(status_code=r.status, content=response_data) else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) + return PlainTextResponse(status_code=r.status, content=response_data) return response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues - log.exception(f"Client error: {str(e)}") - raise HTTPException( - status_code=500, detail="Open WebUI: Server Connection Error" - ) + log.exception(f'Client error: {str(e)}') + raise HTTPException(status_code=500, detail='Open WebUI: Server Connection Error') except Exception as e: - log.exception(f"Unexpected error: {e}") - raise HTTPException( - status_code=500, detail="Open WebUI: Server Connection Error" - ) + log.exception(f'Unexpected error: {e}') + raise HTTPException(status_code=500, detail='Open WebUI: Server Connection Error') def get_azure_allowed_params(api_version: str) -> set[str]: allowed_params = { - "messages", - "temperature", - "role", - "content", - "contentPart", - "contentPartImage", - "enhancements", - "dataSources", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "function_call", - "functions", - "tools", - "tool_choice", - "top_p", - "log_probs", - "top_logprobs", - "response_format", - "seed", - "max_completion_tokens", - "reasoning_effort", + 'messages', + 'temperature', + 'role', + 'content', + 'contentPart', + 'contentPartImage', + 'enhancements', + 'dataSources', + 'n', + 'stream', + 'stop', + 'max_tokens', + 'presence_penalty', + 'frequency_penalty', + 'logit_bias', + 'user', + 'function_call', + 'functions', + 'tools', + 'tool_choice', + 'top_p', + 'log_probs', + 'top_logprobs', + 'response_format', + 'seed', + 'max_completion_tokens', + 'reasoning_effort', } try: - if api_version >= "2024-09-01-preview": - allowed_params.add("stream_options") + if api_version >= '2024-09-01-preview': + allowed_params.add('stream_options') except ValueError: - log.debug( - f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters." - ) + log.debug(f'Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters.') return allowed_params def is_openai_reasoning_model(model: str) -> bool: - return model.lower().startswith(("o1", "o3", "o4", "gpt-5")) + return model.lower().startswith(('o1', 'o3', 'o4', 'gpt-5')) def convert_to_azure_payload(url, payload: dict, api_version: str): - model = payload.get("model", "") + model = payload.get('model', '') # Filter allowed parameters based on Azure OpenAI API allowed_params = get_azure_allowed_params(api_version) @@ -806,21 +752,21 @@ def convert_to_azure_payload(url, payload: dict, api_version: str): # Special handling for o-series models if is_openai_reasoning_model(model): # Convert max_tokens to max_completion_tokens for o-series models - if "max_tokens" in payload: - payload["max_completion_tokens"] = payload["max_tokens"] - del payload["max_tokens"] + if 'max_tokens' in payload: + payload['max_completion_tokens'] = payload['max_tokens'] + del payload['max_tokens'] # Remove temperature if not 1 for o-series models - if "temperature" in payload and payload["temperature"] != 1: + if 'temperature' in payload and payload['temperature'] != 1: log.debug( - f"Removing temperature parameter for o-series model {model} as only default value (1) is supported" + f'Removing temperature parameter for o-series model {model} as only default value (1) is supported' ) - del payload["temperature"] + del payload['temperature'] # Filter out unsupported parameters payload = {k: v for k, v in payload.items() if k in allowed_params} - url = f"{url}/openai/deployments/{model}" + url = f'{url}/openai/deployments/{model}' return url, payload @@ -831,95 +777,87 @@ def convert_to_responses_payload(payload: dict) -> dict: Chat Completions: { messages: [{role, content}], ... } Responses API: { input: [{type: "message", role, content: [...]}], instructions: "system" } """ - messages = payload.pop("messages", []) + messages = payload.pop('messages', []) - system_content = "" + system_content = '' input_items = [] for msg in messages: - role = msg.get("role", "user") - content = msg.get("content", "") + role = msg.get('role', 'user') + content = msg.get('content', '') # Check for stored output items (from previous Responses API turn) - stored_output = msg.get("output") + stored_output = msg.get('output') if stored_output and isinstance(stored_output, list): input_items.extend(stored_output) continue - if role == "system": + if role == 'system': if isinstance(content, str): system_content = content elif isinstance(content, list): - system_content = "\n".join( - p.get("text", "") for p in content if p.get("type") == "text" - ) + system_content = '\n'.join(p.get('text', '') for p in content if p.get('type') == 'text') continue # Convert content format - text_type = "output_text" if role == "assistant" else "input_text" + text_type = 'output_text' if role == 'assistant' else 'input_text' if isinstance(content, str): - content_parts = [{"type": text_type, "text": content}] + content_parts = [{'type': text_type, 'text': content}] elif isinstance(content, list): content_parts = [] for part in content: - if part.get("type") == "text": - content_parts.append( - {"type": text_type, "text": part.get("text", "")} - ) - elif part.get("type") == "image_url": - url_data = part.get("image_url", {}) - url = ( - url_data.get("url", "") - if isinstance(url_data, dict) - else url_data - ) - content_parts.append({"type": "input_image", "image_url": url}) + if part.get('type') == 'text': + content_parts.append({'type': text_type, 'text': part.get('text', '')}) + elif part.get('type') == 'image_url': + url_data = part.get('image_url', {}) + url = url_data.get('url', '') if isinstance(url_data, dict) else url_data + content_parts.append({'type': 'input_image', 'image_url': url}) else: - content_parts = [{"type": text_type, "text": str(content)}] + content_parts = [{'type': text_type, 'text': str(content)}] - input_items.append({"type": "message", "role": role, "content": content_parts}) + input_items.append({'type': 'message', 'role': role, 'content': content_parts}) - responses_payload = {**payload, "input": input_items} + responses_payload = {**payload, 'input': input_items} if system_content: - responses_payload["instructions"] = system_content + responses_payload['instructions'] = system_content - if "max_tokens" in responses_payload: - responses_payload["max_output_tokens"] = responses_payload.pop("max_tokens") + if 'max_tokens' in responses_payload: + responses_payload['max_output_tokens'] = responses_payload.pop('max_tokens') # Remove Chat Completions-only parameters not supported by the Responses API for unsupported_key in ( - "stream_options", - "logit_bias", - "frequency_penalty", - "presence_penalty", - "stop", + 'stream_options', + 'logit_bias', + 'frequency_penalty', + 'presence_penalty', + 'stop', ): responses_payload.pop(unsupported_key, None) # Convert Chat Completions tools format to Responses API format # Chat Completions: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} # Responses API: {"type": "function", "name": ..., "description": ..., "parameters": ...} - if "tools" in responses_payload and isinstance(responses_payload["tools"], list): + if 'tools' in responses_payload and isinstance(responses_payload['tools'], list): converted_tools = [] - for tool in responses_payload["tools"]: - if isinstance(tool, dict) and "function" in tool: - func = tool["function"] - converted_tool = {"type": tool.get("type", "function")} + for tool in responses_payload['tools']: + if isinstance(tool, dict) and 'function' in tool: + func = tool['function'] + converted_tool = {'type': tool.get('type', 'function')} if isinstance(func, dict): - converted_tool["name"] = func.get("name", "") - if "description" in func: - converted_tool["description"] = func["description"] - if "parameters" in func: - converted_tool["parameters"] = func["parameters"] - if "strict" in func: - converted_tool["strict"] = func["strict"] + converted_tool['name'] = func.get('name', '') + if 'description' in func: + converted_tool['description'] = func['description'] + if 'parameters' in func: + converted_tool['parameters'] = func['parameters'] + if 'strict' in func: + converted_tool['strict'] = func['strict'] converted_tools.append(converted_tool) else: # Already in correct format or unknown format, pass through converted_tools.append(tool) - responses_payload["tools"] = converted_tools + responses_payload['tools'] = converted_tools return responses_payload @@ -929,11 +867,11 @@ def convert_responses_result(response: dict) -> dict: Convert non-streaming Responses API result. Just add done flag - pass through raw response, frontend handles output. """ - response["done"] = True + response['done'] = True return response -@router.post("/chat/completions") +@router.post('/chat/completions') async def generate_chat_completion( request: Request, form_data: dict, @@ -948,62 +886,58 @@ async def generate_chat_completion( # bypass_filter is read from request.state to prevent external clients from # setting it via query parameter (CVE fix). Only internal server-side callers # (e.g. utils/chat.py) should set request.state.bypass_filter = True. - bypass_filter = getattr(request.state, "bypass_filter", False) + bypass_filter = getattr(request.state, 'bypass_filter', False) if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True idx = 0 payload = {**form_data} - metadata = payload.pop("metadata", None) + metadata = payload.pop('metadata', None) - model_id = form_data.get("model") + model_id = form_data.get('model') model_info = Models.get_model_by_id(model_id) # Check model info and override the payload if model_info: if model_info.base_model_id: base_model_id = ( - request.base_model_id - if hasattr(request, "base_model_id") - else model_info.base_model_id + request.base_model_id if hasattr(request, 'base_model_id') else model_info.base_model_id ) # Use request's base_model_id if available - payload["model"] = base_model_id + payload['model'] = base_model_id model_id = base_model_id params = model_info.params.model_dump() if params: - system = params.pop("system", None) + system = params.pop('system', None) payload = apply_model_params_to_body_openai(params, payload) if not bypass_system_prompt: payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model - if not bypass_filter and user.role == "user": - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id) - } + if not bypass_filter and user.role == 'user': + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not ( user.id == model_info.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model_info.id, - permission="read", + permission='read', user_group_ids=user_group_ids, ) ): raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) elif not bypass_filter: - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=403, - detail="Model not found", + detail='Model not found', ) # Check if model is already in app state cache to avoid expensive get_all_models() call @@ -1014,11 +948,11 @@ async def generate_chat_completion( model = models.get(model_id) if model: - idx = model["urlIdx"] + idx = model['urlIdx'] else: raise HTTPException( status_code=404, - detail="Model not found", + detail='Model not found', ) # Get the API config for the model @@ -1029,69 +963,67 @@ async def generate_chat_completion( ), # Legacy support ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get('prefix_id', None) if prefix_id: - payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + payload['model'] = payload['model'].replace(f'{prefix_id}.', '') # Add user info to the payload if the model is a pipeline - if "pipeline" in model and model.get("pipeline"): - payload["user"] = { - "name": user.name, - "id": user.id, - "email": user.email, - "role": user.role, + if 'pipeline' in model and model.get('pipeline'): + payload['user'] = { + 'name': user.name, + 'id': user.id, + 'email': user.email, + 'role': user.role, } url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] # Check if model is a reasoning model that needs special handling - if is_openai_reasoning_model(payload["model"]): + if is_openai_reasoning_model(payload['model']): payload = openai_reasoning_model_handler(payload) - elif "api.openai.com" not in url: + elif 'api.openai.com' not in url: # Remove "max_completion_tokens" from the payload for backward compatibility - if "max_completion_tokens" in payload: - payload["max_tokens"] = payload["max_completion_tokens"] - del payload["max_completion_tokens"] + if 'max_completion_tokens' in payload: + payload['max_tokens'] = payload['max_completion_tokens'] + del payload['max_completion_tokens'] - if "max_tokens" in payload and "max_completion_tokens" in payload: - del payload["max_tokens"] + if 'max_tokens' in payload and 'max_completion_tokens' in payload: + del payload['max_tokens'] # Convert the modified body back to JSON - if "logit_bias" in payload and payload["logit_bias"]: - logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"]) + if 'logit_bias' in payload and payload['logit_bias']: + logit_bias = convert_logit_bias_input_to_json(payload['logit_bias']) if logit_bias: - payload["logit_bias"] = json.loads(logit_bias) + payload['logit_bias'] = json.loads(logit_bias) - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, metadata, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, metadata, user=user) - is_responses = api_config.get("api_type") == "responses" + is_responses = api_config.get('api_type') == 'responses' - if api_config.get("azure", False): - api_version = api_config.get("api_version", "2023-03-15-preview") + if api_config.get('azure', False): + api_version = api_config.get('api_version', '2023-03-15-preview') request_url, payload = convert_to_azure_payload(url, payload, api_version) # Only set api-key header if not using Azure Entra ID authentication - auth_type = api_config.get("auth_type", "bearer") - if auth_type not in ("azure_ad", "microsoft_entra_id"): - headers["api-key"] = key + auth_type = api_config.get('auth_type', 'bearer') + if auth_type not in ('azure_ad', 'microsoft_entra_id'): + headers['api-key'] = key - headers["api-version"] = api_version + headers['api-version'] = api_version if is_responses: payload = convert_to_responses_payload(payload) - request_url = f"{request_url}/responses?api-version={api_version}" + request_url = f'{request_url}/responses?api-version={api_version}' else: - request_url = f"{request_url}/chat/completions?api-version={api_version}" + request_url = f'{request_url}/chat/completions?api-version={api_version}' else: if is_responses: payload = convert_to_responses_payload(payload) - request_url = f"{url}/responses" + request_url = f'{url}/responses' else: - request_url = f"{url}/chat/completions" + request_url = f'{url}/chat/completions' payload = json.dumps(payload) @@ -1101,12 +1033,10 @@ async def generate_chat_completion( response = None try: - session = aiohttp.ClientSession( - trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - ) + session = aiohttp.ClientSession(trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)) r = await session.request( - method="POST", + method='POST', url=request_url, data=payload, headers=headers, @@ -1115,7 +1045,7 @@ async def generate_chat_completion( ) # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): + if 'text/event-stream' in r.headers.get('Content-Type', ''): streaming = True return StreamingResponse( stream_wrapper(r, session, stream_chunks_handler), @@ -1145,7 +1075,7 @@ async def generate_chat_completion( raise HTTPException( status_code=r.status if r else 500, - detail="Open WebUI: Server Connection Error", + detail='Open WebUI: Server Connection Error', ) finally: if not streaming: @@ -1168,14 +1098,14 @@ async def embeddings(request: Request, form_data: dict, user): # Prepare payload/body body = json.dumps(form_data) # Find correct backend url/key based on model - model_id = form_data.get("model") + model_id = form_data.get('model') # Check if model is already in app state cache to avoid expensive get_all_models() call models = request.app.state.OPENAI_MODELS if not models or model_id not in models: await get_all_models(request, user=user) models = request.app.state.OPENAI_MODELS if model_id in models: - idx = models[model_id]["urlIdx"] + idx = models[model_id]['urlIdx'] url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] @@ -1188,23 +1118,21 @@ async def embeddings(request: Request, form_data: dict, user): session = None streaming = False - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) try: session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), ) r = await session.request( - method="POST", - url=f"{url}/embeddings", + method='POST', + url=f'{url}/embeddings', data=body, headers=headers, cookies=cookies, ) - if "text/event-stream" in r.headers.get("Content-Type", ""): + if 'text/event-stream' in r.headers.get('Content-Type', ''): streaming = True return StreamingResponse( stream_wrapper(r, session), @@ -1221,16 +1149,14 @@ async def embeddings(request: Request, form_data: dict, user): if isinstance(response_data, (dict, list)): return JSONResponse(status_code=r.status, content=response_data) else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) + return PlainTextResponse(status_code=r.status, content=response_data) return response_data except Exception as e: log.exception(e) raise HTTPException( status_code=r.status if r else 500, - detail="Open WebUI: Server Connection Error", + detail='Open WebUI: Server Connection Error', ) finally: if not streaming: @@ -1238,7 +1164,7 @@ async def embeddings(request: Request, form_data: dict, user): class ResponsesForm(BaseModel): - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') model: str input: Optional[list | str] = None @@ -1257,7 +1183,7 @@ class ResponsesForm(BaseModel): previous_response_id: Optional[str] = None -@router.post("/responses") +@router.post('/responses') async def responses( request: Request, form_data: ResponsesForm, @@ -1278,7 +1204,7 @@ async def responses( await get_all_models(request, user=user) models = request.app.state.OPENAI_MODELS if model_id in models: - idx = models[model_id]["urlIdx"] + idx = models[model_id]['urlIdx'] url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] @@ -1292,32 +1218,28 @@ async def responses( streaming = False try: - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) - if api_config.get("azure", False): - api_version = api_config.get("api_version", "2023-03-15-preview") + if api_config.get('azure', False): + api_version = api_config.get('api_version', '2023-03-15-preview') - auth_type = api_config.get("auth_type", "bearer") - if auth_type not in ("azure_ad", "microsoft_entra_id"): - headers["api-key"] = key + auth_type = api_config.get('auth_type', 'bearer') + if auth_type not in ('azure_ad', 'microsoft_entra_id'): + headers['api-key'] = key - headers["api-version"] = api_version + headers['api-version'] = api_version - model = payload.get("model", "") - request_url = ( - f"{url}/openai/deployments/{model}/responses?api-version={api_version}" - ) + model = payload.get('model', '') + request_url = f'{url}/openai/deployments/{model}/responses?api-version={api_version}' else: - request_url = f"{url}/responses" + request_url = f'{url}/responses' session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), ) r = await session.request( - method="POST", + method='POST', url=request_url, data=body, headers=headers, @@ -1326,7 +1248,7 @@ async def responses( ) # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): + if 'text/event-stream' in r.headers.get('Content-Type', ''): streaming = True return StreamingResponse( stream_wrapper(r, session), @@ -1343,9 +1265,7 @@ async def responses( if isinstance(response_data, (dict, list)): return JSONResponse(status_code=r.status, content=response_data) else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) + return PlainTextResponse(status_code=r.status, content=response_data) return response_data @@ -1353,14 +1273,14 @@ async def responses( log.exception(e) raise HTTPException( status_code=r.status if r else 500, - detail="Open WebUI: Server Connection Error", + detail='Open WebUI: Server Connection Error', ) finally: if not streaming: await cleanup_response(r, session) -@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route('/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE']) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): """ Deprecated: proxy all requests to OpenAI API @@ -1377,14 +1297,14 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): payload = None idx = 0 - model_id = payload.get("model") if isinstance(payload, dict) else None + model_id = payload.get('model') if isinstance(payload, dict) else None if model_id: models = request.app.state.OPENAI_MODELS if not models or model_id not in models: await get_all_models(request, user=user) models = request.app.state.OPENAI_MODELS if model_id in models: - idx = models[model_id]["urlIdx"] + idx = models[model_id]['urlIdx'] url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] @@ -1400,27 +1320,25 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): streaming = False try: - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) + headers, cookies = await get_headers_and_cookies(request, url, key, api_config, user=user) - if api_config.get("azure", False): - api_version = api_config.get("api_version", "2023-03-15-preview") + if api_config.get('azure', False): + api_version = api_config.get('api_version', '2023-03-15-preview') # Only set api-key header if not using Azure Entra ID authentication - auth_type = api_config.get("auth_type", "bearer") - if auth_type not in ("azure_ad", "microsoft_entra_id"): - headers["api-key"] = key + auth_type = api_config.get('auth_type', 'bearer') + if auth_type not in ('azure_ad', 'microsoft_entra_id'): + headers['api-key'] = key - headers["api-version"] = api_version + headers['api-version'] = api_version payload = json.loads(body) url, payload = convert_to_azure_payload(url, payload, api_version) body = json.dumps(payload).encode() - request_url = f"{url}/{path}?api-version={api_version}" + request_url = f'{url}/{path}?api-version={api_version}' else: - request_url = f"{url}/{path}" + request_url = f'{url}/{path}' session = aiohttp.ClientSession( trust_env=True, @@ -1436,7 +1354,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ) # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): + if 'text/event-stream' in r.headers.get('Content-Type', ''): streaming = True return StreamingResponse( stream_wrapper(r, session), @@ -1453,9 +1371,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): if isinstance(response_data, (dict, list)): return JSONResponse(status_code=r.status, content=response_data) else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) + return PlainTextResponse(status_code=r.status, content=response_data) return response_data @@ -1463,7 +1379,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): log.exception(e) raise HTTPException( status_code=r.status if r else 500, - detail="Open WebUI: Server Connection Error", + detail='Open WebUI: Server Connection Error', ) finally: if not streaming: diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 24cf6682c7..4f1022476b 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -40,33 +40,30 @@ def get_sorted_filters(model_id, models): filters = [ model for model in models.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" + if 'pipeline' in model + and 'type' in model['pipeline'] + and model['pipeline']['type'] == 'filter' and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) + model['pipeline']['pipelines'] == ['*'] + or any(model_id == target_model_id for target_model_id in model['pipeline']['pipelines']) ) ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + sorted_filters = sorted(filters, key=lambda x: x['pipeline']['priority']) return sorted_filters async def process_pipeline_inlet_filter(request, payload, user, models): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] + user = {'id': user.id, 'email': user.email, 'name': user.name, 'role': user.role} + model_id = payload['model'] sorted_filters = get_sorted_filters(model_id, models) model = models[model_id] - if "pipeline" in model: + if 'pipeline' in model: sorted_filters.append(model) async with aiohttp.ClientSession(trust_env=True) as session: for filter in sorted_filters: - urlIdx = filter.get("urlIdx") + urlIdx = filter.get('urlIdx') try: urlIdx = int(urlIdx) @@ -79,15 +76,15 @@ async def process_pipeline_inlet_filter(request, payload, user, models): if not key: continue - headers = {"Authorization": f"Bearer {key}"} + headers = {'Authorization': f'Bearer {key}'} request_data = { - "user": user, - "body": payload, + 'user': user, + 'body': payload, } try: async with session.post( - f"{url}/{filter['id']}/filter/inlet", + f'{url}/{filter["id"]}/filter/inlet', headers=headers, json=request_data, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -95,31 +92,27 @@ async def process_pipeline_inlet_filter(request, payload, user, models): response.raise_for_status() payload = await response.json() except aiohttp.ClientResponseError as e: - res = ( - await response.json() - if response.content_type == "application/json" - else {} - ) - if "detail" in res: - raise Exception(response.status, res["detail"]) + res = await response.json() if response.content_type == 'application/json' else {} + if 'detail' in res: + raise Exception(response.status, res['detail']) except Exception as e: - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') return payload async def process_pipeline_outlet_filter(request, payload, user, models): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] + user = {'id': user.id, 'email': user.email, 'name': user.name, 'role': user.role} + model_id = payload['model'] sorted_filters = get_sorted_filters(model_id, models) model = models[model_id] - if "pipeline" in model: + if 'pipeline' in model: sorted_filters = [model] + sorted_filters async with aiohttp.ClientSession(trust_env=True) as session: for filter in sorted_filters: - urlIdx = filter.get("urlIdx") + urlIdx = filter.get('urlIdx') try: urlIdx = int(urlIdx) @@ -132,15 +125,15 @@ async def process_pipeline_outlet_filter(request, payload, user, models): if not key: continue - headers = {"Authorization": f"Bearer {key}"} + headers = {'Authorization': f'Bearer {key}'} request_data = { - "user": user, - "body": payload, + 'user': user, + 'body': payload, } try: async with session.post( - f"{url}/{filter['id']}/filter/outlet", + f'{url}/{filter["id"]}/filter/outlet', headers=headers, json=request_data, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -149,17 +142,13 @@ async def process_pipeline_outlet_filter(request, payload, user, models): payload = await response.json() except aiohttp.ClientResponseError as e: try: - res = ( - await response.json() - if "application/json" in response.content_type - else {} - ) - if "detail" in res: + res = await response.json() if 'application/json' in response.content_type else {} + if 'detail' in res: raise Exception(response.status, res) except Exception: pass except Exception as e: - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') return payload @@ -173,72 +162,68 @@ async def process_pipeline_outlet_filter(request, payload, user, models): router = APIRouter() -@router.get("/list") +@router.get('/list') async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): responses = await get_all_models_responses(request, user) - log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") + log.debug(f'get_pipelines_list: get_openai_models_responses returned {responses}') - urlIdxs = [ - idx - for idx, response in enumerate(responses) - if response is not None and "pipelines" in response - ] + urlIdxs = [idx for idx, response in enumerate(responses) if response is not None and 'pipelines' in response] return { - "data": [ + 'data': [ { - "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx], - "idx": urlIdx, + 'url': request.app.state.config.OPENAI_API_BASE_URLS[urlIdx], + 'idx': urlIdx, } for urlIdx in urlIdxs ] } -@router.post("/upload") +@router.post('/upload') async def upload_pipeline( request: Request, urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user), ): - log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}") + log.info(f'upload_pipeline: urlIdx={urlIdx}, filename={file.filename}') filename = os.path.basename(file.filename) # Check if the uploaded file is a python file - if not (filename and filename.endswith(".py")): + if not (filename and filename.endswith('.py')): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Only Python (.py) files are allowed.", + detail='Only Python (.py) files are allowed.', ) - upload_folder = f"{CACHE_DIR}/pipelines" + upload_folder = f'{CACHE_DIR}/pipelines' os.makedirs(upload_folder, exist_ok=True) file_path = os.path.join(upload_folder, filename) response = None try: # Save the uploaded file - with open(file_path, "wb") as buffer: + with open(file_path, 'wb') as buffer: shutil.copyfileobj(file.file, buffer) url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} + headers = {'Authorization': f'Bearer {key}'} async with aiohttp.ClientSession(trust_env=True) as session: - with open(file_path, "rb") as f: + with open(file_path, 'rb') as f: form_data = aiohttp.FormData() form_data.add_field( - "file", + 'file', f, filename=filename, - content_type="application/octet-stream", + content_type='application/octet-stream', ) async with session.post( - f"{url}/pipelines/upload", + f'{url}/pipelines/upload', headers=headers, data=form_data, ssl=AIOHTTP_CLIENT_SESSION_SSL, @@ -249,7 +234,7 @@ async def upload_pipeline( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None status_code = status.HTTP_404_NOT_FOUND @@ -257,14 +242,14 @@ async def upload_pipeline( status_code = response.status try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( status_code=status_code, - detail=detail if detail else "Pipeline not found", + detail=detail if detail else 'Pipeline not found', ) finally: # Ensure the file is deleted after the upload is completed or on failure @@ -277,10 +262,8 @@ class AddPipelineForm(BaseModel): urlIdx: int -@router.post("/add") -async def add_pipeline( - request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) -): +@router.post('/add') +async def add_pipeline(request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)): response = None try: urlIdx = form_data.urlIdx @@ -290,9 +273,9 @@ async def add_pipeline( async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - f"{url}/pipelines/add", - headers={"Authorization": f"Bearer {key}"}, - json={"url": form_data.url}, + f'{url}/pipelines/add', + headers={'Authorization': f'Bearer {key}'}, + json={'url': form_data.url}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() @@ -301,22 +284,20 @@ async def add_pipeline( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) @@ -325,10 +306,8 @@ class DeletePipelineForm(BaseModel): urlIdx: int -@router.delete("/delete") -async def delete_pipeline( - request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) -): +@router.delete('/delete') +async def delete_pipeline(request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)): response = None try: urlIdx = form_data.urlIdx @@ -338,9 +317,9 @@ async def delete_pipeline( async with aiohttp.ClientSession(trust_env=True) as session: async with session.delete( - f"{url}/pipelines/delete", - headers={"Authorization": f"Bearer {key}"}, - json={"id": form_data.id}, + f'{url}/pipelines/delete', + headers={'Authorization': f'Bearer {key}'}, + json={'id': form_data.id}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() @@ -349,29 +328,25 @@ async def delete_pipeline( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) -@router.get("/") -async def get_pipelines( - request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) -): +@router.get('/') +async def get_pipelines(request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)): response = None try: url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] @@ -379,8 +354,8 @@ async def get_pipelines( async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - f"{url}/pipelines", - headers={"Authorization": f"Bearer {key}"}, + f'{url}/pipelines', + headers={'Authorization': f'Bearer {key}'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() @@ -389,26 +364,24 @@ async def get_pipelines( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) -@router.get("/{pipeline_id}/valves") +@router.get('/{pipeline_id}/valves') async def get_pipeline_valves( request: Request, urlIdx: Optional[int], @@ -422,8 +395,8 @@ async def get_pipeline_valves( async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - f"{url}/{pipeline_id}/valves", - headers={"Authorization": f"Bearer {key}"}, + f'{url}/{pipeline_id}/valves', + headers={'Authorization': f'Bearer {key}'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() @@ -432,26 +405,24 @@ async def get_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) -@router.get("/{pipeline_id}/valves/spec") +@router.get('/{pipeline_id}/valves/spec') async def get_pipeline_valves_spec( request: Request, urlIdx: Optional[int], @@ -465,8 +436,8 @@ async def get_pipeline_valves_spec( async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - f"{url}/{pipeline_id}/valves/spec", - headers={"Authorization": f"Bearer {key}"}, + f'{url}/{pipeline_id}/valves/spec', + headers={'Authorization': f'Bearer {key}'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: response.raise_for_status() @@ -475,26 +446,24 @@ async def get_pipeline_valves_spec( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) -@router.post("/{pipeline_id}/valves/update") +@router.post('/{pipeline_id}/valves/update') async def update_pipeline_valves( request: Request, urlIdx: Optional[int], @@ -509,8 +478,8 @@ async def update_pipeline_valves( async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - f"{url}/{pipeline_id}/valves/update", - headers={"Authorization": f"Bearer {key}"}, + f'{url}/{pipeline_id}/valves/update', + headers={'Authorization': f'Bearer {key}'}, json={**form_data}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: @@ -520,21 +489,19 @@ async def update_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - log.exception(f"Connection error: {e}") + log.exception(f'Connection error: {e}') detail = None if response is not None: try: res = await response.json() - if "detail" in res: - detail = res["detail"] + if 'detail' in res: + detail = res['detail'] except Exception: pass raise HTTPException( - status_code=( - response.status if response is not None else status.HTTP_404_NOT_FOUND - ), - detail=detail if detail else "Pipeline not found", + status_code=(response.status if response is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else 'Pipeline not found', ) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 1f3342dad7..df07c778c1 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -45,26 +45,22 @@ PAGE_ITEM_COUNT = 30 ############################ -@router.get("/", response_model=list[PromptModel]) -async def get_prompts( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: +@router.get('/', response_model=list[PromptModel]) +async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: prompts = Prompts.get_prompts(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) + prompts = Prompts.get_prompts_by_user_id(user.id, 'read', db=db) return prompts -@router.get("/tags", response_model=list[str]) -async def get_prompt_tags( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: +@router.get('/tags', response_model=list[str]) +async def get_prompt_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: return Prompts.get_tags(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) + prompts = Prompts.get_prompts_by_user_id(user.id, 'read', db=db) tags = set() for prompt in prompts: if prompt.tags: @@ -72,7 +68,7 @@ async def get_prompt_tags( return sorted(list(tags)) -@router.get("/list", response_model=PromptAccessListResponse) +@router.get('/list', response_model=PromptAccessListResponse) async def get_prompt_list( query: Optional[str] = None, view_option: Optional[str] = None, @@ -90,37 +86,35 @@ async def get_prompt_list( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option if tag: - filter["tag"] = tag + filter['tag'] = tag if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction # Pre-fetch user group IDs once - used for both filter and write_access check groups = Groups.get_groups_by_member_id(user.id, db=db) user_group_ids = {group.id for group in groups} - if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + if not (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL): if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id - result = Prompts.search_prompts( - user.id, filter=filter, skip=skip, limit=limit, db=db - ) + result = Prompts.search_prompts(user.id, filter=filter, skip=skip, limit=limit, db=db) # Batch-fetch writable prompt IDs in a single query instead of N has_access calls prompt_ids = [prompt.id for prompt in result.items] writable_prompt_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_ids=prompt_ids, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ) @@ -130,7 +124,7 @@ async def get_prompt_list( PromptAccessResponse( **prompt.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id or prompt.id in writable_prompt_ids ), @@ -146,23 +140,23 @@ async def get_prompt_list( ############################ -@router.post("/create", response_model=Optional[PromptModel]) +@router.post('/create', response_model=Optional[PromptModel]) async def create_new_prompt( request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not ( + if user.role != 'admin' and not ( has_permission( user.id, - "workspace.prompts", + 'workspace.prompts', request.app.state.config.USER_PERMISSIONS, db=db, ) or has_permission( user.id, - "workspace.prompts_import", + 'workspace.prompts_import', request.app.state.config.USER_PERMISSIONS, db=db, ) @@ -193,34 +187,32 @@ async def create_new_prompt( ############################ -@router.get("/command/{command}", response_model=Optional[PromptAccessResponse]) -async def get_prompt_by_command( - command: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/command/{command}', response_model=Optional[PromptAccessResponse]) +async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): prompt = Prompts.get_prompt_by_command(command, db=db) if prompt: if ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="read", + permission='read', db=db, ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) ), @@ -237,34 +229,32 @@ async def get_prompt_by_command( ############################ -@router.get("/id/{prompt_id}", response_model=Optional[PromptAccessResponse]) -async def get_prompt_by_id( - prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{prompt_id}', response_model=Optional[PromptAccessResponse]) +async def get_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): prompt = Prompts.get_prompt_by_id(prompt_id, db=db) if prompt: if ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="read", + permission='read', db=db, ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) ), @@ -281,7 +271,7 @@ async def get_prompt_by_id( ############################ -@router.post("/id/{prompt_id}/update", response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update', response_model=Optional[PromptModel]) async def update_prompt_by_id( prompt_id: str, form_data: PromptForm, @@ -301,12 +291,12 @@ async def update_prompt_by_id( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -338,7 +328,7 @@ async def update_prompt_by_id( ############################ -@router.post("/id/{prompt_id}/update/meta", response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update/meta', response_model=Optional[PromptModel]) async def update_prompt_metadata( prompt_id: str, form_data: PromptMetadataForm, @@ -358,12 +348,12 @@ async def update_prompt_metadata( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -379,9 +369,7 @@ async def update_prompt_metadata( detail=f"Command '/{form_data.command}' is already in use", ) - updated_prompt = Prompts.update_prompt_metadata( - prompt.id, form_data.name, form_data.command, form_data.tags, db=db - ) + updated_prompt = Prompts.update_prompt_metadata(prompt.id, form_data.name, form_data.command, form_data.tags, db=db) if updated_prompt: return updated_prompt else: @@ -391,7 +379,7 @@ async def update_prompt_metadata( ) -@router.post("/id/{prompt_id}/update/version", response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update/version', response_model=Optional[PromptModel]) async def set_prompt_version( prompt_id: str, form_data: PromptVersionUpdateForm, @@ -409,21 +397,19 @@ async def set_prompt_version( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - updated_prompt = Prompts.update_prompt_version( - prompt.id, form_data.version_id, db=db - ) + updated_prompt = Prompts.update_prompt_version(prompt.id, form_data.version_id, db=db) if updated_prompt: return updated_prompt else: @@ -442,7 +428,7 @@ class PromptAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/id/{prompt_id}/access/update", response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/access/update', response_model=Optional[PromptModel]) async def update_prompt_access_by_id( request: Request, prompt_id: str, @@ -461,12 +447,12 @@ async def update_prompt_access_by_id( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -478,10 +464,10 @@ async def update_prompt_access_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_prompts", + 'sharing.public_prompts', ) - AccessGrants.set_access_grants("prompt", prompt_id, form_data.access_grants, db=db) + AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db) return Prompts.get_prompt_by_id(prompt_id, db=db) @@ -491,10 +477,8 @@ async def update_prompt_access_by_id( ############################ -@router.post("/id/{prompt_id}/toggle", response_model=Optional[PromptModel]) -async def toggle_prompt_active( - prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/id/{prompt_id}/toggle', response_model=Optional[PromptModel]) +async def toggle_prompt_active(prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): prompt = Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: @@ -507,12 +491,12 @@ async def toggle_prompt_active( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -533,10 +517,8 @@ async def toggle_prompt_active( ############################ -@router.delete("/id/{prompt_id}/delete", response_model=bool) -async def delete_prompt_by_id( - prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.delete('/id/{prompt_id}/delete', response_model=bool) +async def delete_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): prompt = Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: @@ -549,12 +531,12 @@ async def delete_prompt_by_id( prompt.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -570,7 +552,7 @@ async def delete_prompt_by_id( ############################ -@router.get("/id/{prompt_id}/history", response_model=list[PromptHistoryResponse]) +@router.get('/id/{prompt_id}/history', response_model=list[PromptHistoryResponse]) async def get_prompt_history( prompt_id: str, page: int = 0, @@ -590,13 +572,13 @@ async def get_prompt_history( # Check read access if not ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="read", + permission='read', db=db, ) ): @@ -605,13 +587,11 @@ async def get_prompt_history( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - history = PromptHistories.get_history_by_prompt_id( - prompt.id, limit=PAGE_SIZE, offset=page * PAGE_SIZE, db=db - ) + history = PromptHistories.get_history_by_prompt_id(prompt.id, limit=PAGE_SIZE, offset=page * PAGE_SIZE, db=db) return history -@router.get("/id/{prompt_id}/history/{history_id}", response_model=PromptHistoryModel) +@router.get('/id/{prompt_id}/history/{history_id}', response_model=PromptHistoryModel) async def get_prompt_history_entry( prompt_id: str, history_id: str, @@ -629,13 +609,13 @@ async def get_prompt_history_entry( # Check read access if not ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="read", + permission='read', db=db, ) ): @@ -654,7 +634,7 @@ async def get_prompt_history_entry( return history_entry -@router.delete("/id/{prompt_id}/history/{history_id}", response_model=bool) +@router.delete('/id/{prompt_id}/history/{history_id}', response_model=bool) async def delete_prompt_history_entry( prompt_id: str, history_id: str, @@ -672,13 +652,13 @@ async def delete_prompt_history_entry( # Check write access if not ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="write", + permission='write', db=db, ) ): @@ -691,7 +671,7 @@ async def delete_prompt_history_entry( if prompt.version_id == history_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot delete the active production version", + detail='Cannot delete the active production version', ) success = PromptHistories.delete_history_entry(history_id, db=db) @@ -704,7 +684,7 @@ async def delete_prompt_history_entry( return success -@router.get("/id/{prompt_id}/history/diff") +@router.get('/id/{prompt_id}/history/diff') async def get_prompt_diff( prompt_id: str, from_id: str, @@ -723,13 +703,13 @@ async def get_prompt_diff( # Check read access if not ( - user.role == "admin" + user.role == 'admin' or prompt.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, - permission="read", + permission='read', db=db, ) ): @@ -742,7 +722,7 @@ async def get_prompt_diff( if not diff: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="One or both history entries not found", + detail='One or both history entries not found', ) return diff diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 3e509d2f88..30b69ee041 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -137,7 +137,7 @@ def get_ef( auto_update: bool = RAG_EMBEDDING_MODEL_AUTO_UPDATE, ): ef = None - if embedding_model and engine == "": + if embedding_model and engine == '': from sentence_transformers import SentenceTransformer try: @@ -149,39 +149,37 @@ def get_ef( model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS, ) except Exception as e: - log.debug(f"Error loading SentenceTransformer: {e}") + log.debug(f'Error loading SentenceTransformer: {e}') return ef def get_rf( - engine: str = "", + engine: str = '', reranking_model: Optional[str] = None, - external_reranker_url: str = "", - external_reranker_api_key: str = "", - external_reranker_timeout: str = "", + external_reranker_url: str = '', + external_reranker_api_key: str = '', + external_reranker_timeout: str = '', auto_update: bool = RAG_RERANKING_MODEL_AUTO_UPDATE, ): rf = None # Convert timeout string to int or None (system default) - timeout_value = ( - int(external_reranker_timeout) if external_reranker_timeout else None - ) + timeout_value = int(external_reranker_timeout) if external_reranker_timeout else None if reranking_model: - if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): + if any(model in reranking_model for model in ['jinaai/jina-colbert-v2']): try: from open_webui.retrieval.models.colbert import ColBERT rf = ColBERT( get_model_path(reranking_model, auto_update), - env="docker" if DOCKER else None, + env='docker' if DOCKER else None, ) except Exception as e: - log.error(f"ColBERT: {e}") + log.error(f'ColBERT: {e}') raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: - if engine == "external": + if engine == 'external': try: from open_webui.retrieval.models.external import ExternalReranker @@ -192,7 +190,7 @@ def get_rf( timeout=timeout_value, ) except Exception as e: - log.error(f"ExternalReranking: {e}") + log.error(f'ExternalReranking: {e}') raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers @@ -212,28 +210,24 @@ def get_rf( ), ) except Exception as e: - log.error(f"CrossEncoder: {e}") - raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + log.error(f'CrossEncoder: {e}') + raise Exception(ERROR_MESSAGES.DEFAULT('CrossEncoder error')) # Safely adjust pad_token_id if missing as some models do not have this in config try: - model_cfg = getattr(rf, "model", None) - if model_cfg and hasattr(model_cfg, "config"): + model_cfg = getattr(rf, 'model', None) + if model_cfg and hasattr(model_cfg, 'config'): cfg = model_cfg.config - if getattr(cfg, "pad_token_id", None) is None: + if getattr(cfg, 'pad_token_id', None) is None: # Fallback to eos_token_id when available - eos = getattr(cfg, "eos_token_id", None) + eos = getattr(cfg, 'eos_token_id', None) if eos is not None: cfg.pad_token_id = eos - log.debug( - f"Missing pad_token_id detected; set to eos_token_id={eos}" - ) + log.debug(f'Missing pad_token_id detected; set to eos_token_id={eos}') else: - log.warning( - "Neither pad_token_id nor eos_token_id present in model config" - ) + log.warning('Neither pad_token_id nor eos_token_id present in model config') except Exception as e2: - log.warning(f"Failed to adjust pad_token_id on CrossEncoder: {e2}") + log.warning(f'Failed to adjust pad_token_id on CrossEncoder: {e2}') return rf @@ -260,43 +254,43 @@ class SearchForm(BaseModel): queries: List[str] -@router.get("/") +@router.get('/') async def get_status(request: Request): return { - "status": True, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE, - "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL, - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, - "RAG_EMBEDDING_CONCURRENT_REQUESTS": request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, + 'status': True, + 'CHUNK_SIZE': request.app.state.config.CHUNK_SIZE, + 'CHUNK_OVERLAP': request.app.state.config.CHUNK_OVERLAP, + 'RAG_TEMPLATE': request.app.state.config.RAG_TEMPLATE, + 'RAG_EMBEDDING_ENGINE': request.app.state.config.RAG_EMBEDDING_ENGINE, + 'RAG_EMBEDDING_MODEL': request.app.state.config.RAG_EMBEDDING_MODEL, + 'RAG_RERANKING_MODEL': request.app.state.config.RAG_RERANKING_MODEL, + 'RAG_EMBEDDING_BATCH_SIZE': request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + 'ENABLE_ASYNC_EMBEDDING': request.app.state.config.ENABLE_ASYNC_EMBEDDING, + 'RAG_EMBEDDING_CONCURRENT_REQUESTS': request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, } -@router.get("/embedding") +@router.get('/embedding') async def get_embedding_config(request: Request, user=Depends(get_admin_user)): return { - "status": True, - "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE, - "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL, - "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, - "RAG_EMBEDDING_CONCURRENT_REQUESTS": request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, - "openai_config": { - "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, - "key": request.app.state.config.RAG_OPENAI_API_KEY, + 'status': True, + 'RAG_EMBEDDING_ENGINE': request.app.state.config.RAG_EMBEDDING_ENGINE, + 'RAG_EMBEDDING_MODEL': request.app.state.config.RAG_EMBEDDING_MODEL, + 'RAG_EMBEDDING_BATCH_SIZE': request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + 'ENABLE_ASYNC_EMBEDDING': request.app.state.config.ENABLE_ASYNC_EMBEDDING, + 'RAG_EMBEDDING_CONCURRENT_REQUESTS': request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, + 'openai_config': { + 'url': request.app.state.config.RAG_OPENAI_API_BASE_URL, + 'key': request.app.state.config.RAG_OPENAI_API_KEY, }, - "ollama_config": { - "url": request.app.state.config.RAG_OLLAMA_BASE_URL, - "key": request.app.state.config.RAG_OLLAMA_API_KEY, + 'ollama_config': { + 'url': request.app.state.config.RAG_OLLAMA_BASE_URL, + 'key': request.app.state.config.RAG_OLLAMA_API_KEY, }, - "azure_openai_config": { - "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, - "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, - "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, + 'azure_openai_config': { + 'url': request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, + 'key': request.app.state.config.RAG_AZURE_OPENAI_API_KEY, + 'version': request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, }, } @@ -329,72 +323,50 @@ class EmbeddingModelUpdateForm(BaseModel): def unload_embedding_model(request: Request): - if request.app.state.config.RAG_EMBEDDING_ENGINE == "": + if request.app.state.config.RAG_EMBEDDING_ENGINE == '': # unloads current internal embedding model and clears VRAM cache request.app.state.ef = None request.app.state.EMBEDDING_FUNCTION = None import gc gc.collect() - if DEVICE_TYPE == "cuda": + if DEVICE_TYPE == 'cuda': import torch if torch.cuda.is_available(): torch.cuda.empty_cache() -@router.post("/embedding/update") -async def update_embedding_config( - request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) -): +@router.post('/embedding/update') +async def update_embedding_config(request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)): log.info( - f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.RAG_EMBEDDING_MODEL}" + f'Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.RAG_EMBEDDING_MODEL}' ) unload_embedding_model(request) try: request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.RAG_EMBEDDING_ENGINE request.app.state.config.RAG_EMBEDDING_MODEL = form_data.RAG_EMBEDDING_MODEL - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( - form_data.RAG_EMBEDDING_BATCH_SIZE - ) - request.app.state.config.ENABLE_ASYNC_EMBEDDING = ( - form_data.ENABLE_ASYNC_EMBEDDING - ) - request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS = ( - form_data.RAG_EMBEDDING_CONCURRENT_REQUESTS - ) + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.RAG_EMBEDDING_BATCH_SIZE + request.app.state.config.ENABLE_ASYNC_EMBEDDING = form_data.ENABLE_ASYNC_EMBEDDING + request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS = form_data.RAG_EMBEDDING_CONCURRENT_REQUESTS if request.app.state.config.RAG_EMBEDDING_ENGINE in [ - "ollama", - "openai", - "azure_openai", + 'ollama', + 'openai', + 'azure_openai', ]: if form_data.openai_config is not None: - request.app.state.config.RAG_OPENAI_API_BASE_URL = ( - form_data.openai_config.url - ) - request.app.state.config.RAG_OPENAI_API_KEY = ( - form_data.openai_config.key - ) + request.app.state.config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url + request.app.state.config.RAG_OPENAI_API_KEY = form_data.openai_config.key if form_data.ollama_config is not None: - request.app.state.config.RAG_OLLAMA_BASE_URL = ( - form_data.ollama_config.url - ) - request.app.state.config.RAG_OLLAMA_API_KEY = ( - form_data.ollama_config.key - ) + request.app.state.config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url + request.app.state.config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) + request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = form_data.azure_openai_config.url + request.app.state.config.RAG_AZURE_OPENAI_API_KEY = form_data.azure_openai_config.key + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = form_data.azure_openai_config.version request.app.state.ef = get_ef( request.app.state.config.RAG_EMBEDDING_ENGINE, @@ -407,26 +379,26 @@ async def update_embedding_config( request.app.state.ef, ( request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( request.app.state.config.RAG_OLLAMA_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL ) ), ( request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( request.app.state.config.RAG_OLLAMA_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else request.app.state.config.RAG_AZURE_OPENAI_API_KEY ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, azure_api_version=( request.app.state.config.RAG_AZURE_OPENAI_API_VERSION - if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'azure_openai' else None ), enable_async=request.app.state.config.ENABLE_ASYNC_EMBEDDING, @@ -434,167 +406,167 @@ async def update_embedding_config( ) return { - "status": True, - "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE, - "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL, - "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, - "RAG_EMBEDDING_CONCURRENT_REQUESTS": request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, - "openai_config": { - "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, - "key": request.app.state.config.RAG_OPENAI_API_KEY, + 'status': True, + 'RAG_EMBEDDING_ENGINE': request.app.state.config.RAG_EMBEDDING_ENGINE, + 'RAG_EMBEDDING_MODEL': request.app.state.config.RAG_EMBEDDING_MODEL, + 'RAG_EMBEDDING_BATCH_SIZE': request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + 'ENABLE_ASYNC_EMBEDDING': request.app.state.config.ENABLE_ASYNC_EMBEDDING, + 'RAG_EMBEDDING_CONCURRENT_REQUESTS': request.app.state.config.RAG_EMBEDDING_CONCURRENT_REQUESTS, + 'openai_config': { + 'url': request.app.state.config.RAG_OPENAI_API_BASE_URL, + 'key': request.app.state.config.RAG_OPENAI_API_KEY, }, - "ollama_config": { - "url": request.app.state.config.RAG_OLLAMA_BASE_URL, - "key": request.app.state.config.RAG_OLLAMA_API_KEY, + 'ollama_config': { + 'url': request.app.state.config.RAG_OLLAMA_BASE_URL, + 'key': request.app.state.config.RAG_OLLAMA_API_KEY, }, - "azure_openai_config": { - "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, - "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, - "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, + 'azure_openai_config': { + 'url': request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, + 'key': request.app.state.config.RAG_AZURE_OPENAI_API_KEY, + 'version': request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, }, } except Exception as e: - log.exception(f"Problem updating embedding model: {e}") + log.exception(f'Problem updating embedding model: {e}') raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=ERROR_MESSAGES.DEFAULT(e), ) -@router.get("/config") +@router.get('/config') async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { - "status": True, + 'status': True, # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + 'RAG_TEMPLATE': request.app.state.config.RAG_TEMPLATE, + 'TOP_K': request.app.state.config.TOP_K, + 'BYPASS_EMBEDDING_AND_RETRIEVAL': request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, + 'RAG_FULL_CONTEXT': request.app.state.config.RAG_FULL_CONTEXT, # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, - "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, + 'ENABLE_RAG_HYBRID_SEARCH': request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + 'ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS': request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS, + 'TOP_K_RERANKER': request.app.state.config.TOP_K_RERANKER, + 'RELEVANCE_THRESHOLD': request.app.state.config.RELEVANCE_THRESHOLD, + 'HYBRID_BM25_WEIGHT': request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "PDF_LOADER_MODE": request.app.state.config.PDF_LOADER_MODE, - "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, - "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, - "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, - "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, - "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, - "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - "DATALAB_MARKER_FORMAT_LINES": request.app.state.config.DATALAB_MARKER_FORMAT_LINES, - "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, - "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_API_KEY": request.app.state.config.DOCLING_API_KEY, - "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "DOCUMENT_INTELLIGENCE_MODEL": request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL, - "MISTRAL_OCR_API_BASE_URL": request.app.state.config.MISTRAL_OCR_API_BASE_URL, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + 'CONTENT_EXTRACTION_ENGINE': request.app.state.config.CONTENT_EXTRACTION_ENGINE, + 'PDF_EXTRACT_IMAGES': request.app.state.config.PDF_EXTRACT_IMAGES, + 'PDF_LOADER_MODE': request.app.state.config.PDF_LOADER_MODE, + 'DATALAB_MARKER_API_KEY': request.app.state.config.DATALAB_MARKER_API_KEY, + 'DATALAB_MARKER_API_BASE_URL': request.app.state.config.DATALAB_MARKER_API_BASE_URL, + 'DATALAB_MARKER_ADDITIONAL_CONFIG': request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + 'DATALAB_MARKER_SKIP_CACHE': request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + 'DATALAB_MARKER_FORCE_OCR': request.app.state.config.DATALAB_MARKER_FORCE_OCR, + 'DATALAB_MARKER_PAGINATE': request.app.state.config.DATALAB_MARKER_PAGINATE, + 'DATALAB_MARKER_STRIP_EXISTING_OCR': request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + 'DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION': request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + 'DATALAB_MARKER_FORMAT_LINES': request.app.state.config.DATALAB_MARKER_FORMAT_LINES, + 'DATALAB_MARKER_USE_LLM': request.app.state.config.DATALAB_MARKER_USE_LLM, + 'DATALAB_MARKER_OUTPUT_FORMAT': request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, + 'EXTERNAL_DOCUMENT_LOADER_URL': request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + 'EXTERNAL_DOCUMENT_LOADER_API_KEY': request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + 'TIKA_SERVER_URL': request.app.state.config.TIKA_SERVER_URL, + 'DOCLING_SERVER_URL': request.app.state.config.DOCLING_SERVER_URL, + 'DOCLING_API_KEY': request.app.state.config.DOCLING_API_KEY, + 'DOCLING_PARAMS': request.app.state.config.DOCLING_PARAMS, + 'DOCUMENT_INTELLIGENCE_ENDPOINT': request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + 'DOCUMENT_INTELLIGENCE_KEY': request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + 'DOCUMENT_INTELLIGENCE_MODEL': request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL, + 'MISTRAL_OCR_API_BASE_URL': request.app.state.config.MISTRAL_OCR_API_BASE_URL, + 'MISTRAL_OCR_API_KEY': request.app.state.config.MISTRAL_OCR_API_KEY, # MinerU settings - "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE, - "MINERU_API_URL": request.app.state.config.MINERU_API_URL, - "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY, - "MINERU_API_TIMEOUT": request.app.state.config.MINERU_API_TIMEOUT, - "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS, + 'MINERU_API_MODE': request.app.state.config.MINERU_API_MODE, + 'MINERU_API_URL': request.app.state.config.MINERU_API_URL, + 'MINERU_API_KEY': request.app.state.config.MINERU_API_KEY, + 'MINERU_API_TIMEOUT': request.app.state.config.MINERU_API_TIMEOUT, + 'MINERU_PARAMS': request.app.state.config.MINERU_PARAMS, # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - "RAG_EXTERNAL_RERANKER_TIMEOUT": request.app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT, + 'RAG_RERANKING_MODEL': request.app.state.config.RAG_RERANKING_MODEL, + 'RAG_RERANKING_ENGINE': request.app.state.config.RAG_RERANKING_ENGINE, + 'RAG_EXTERNAL_RERANKER_URL': request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + 'RAG_EXTERNAL_RERANKER_API_KEY': request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + 'RAG_EXTERNAL_RERANKER_TIMEOUT': request.app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT, # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER": request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_MIN_SIZE_TARGET": request.app.state.config.CHUNK_MIN_SIZE_TARGET, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + 'TEXT_SPLITTER': request.app.state.config.TEXT_SPLITTER, + 'ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER': request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, + 'CHUNK_SIZE': request.app.state.config.CHUNK_SIZE, + 'CHUNK_MIN_SIZE_TARGET': request.app.state.config.CHUNK_MIN_SIZE_TARGET, + 'CHUNK_OVERLAP': request.app.state.config.CHUNK_OVERLAP, # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + 'FILE_MAX_SIZE': request.app.state.config.FILE_MAX_SIZE, + 'FILE_MAX_COUNT': request.app.state.config.FILE_MAX_COUNT, + 'FILE_IMAGE_COMPRESSION_WIDTH': request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + 'FILE_IMAGE_COMPRESSION_HEIGHT': request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + 'ALLOWED_FILE_EXTENSIONS': request.app.state.config.ALLOWED_FILE_EXTENSIONS, # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + 'ENABLE_GOOGLE_DRIVE_INTEGRATION': request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + 'ENABLE_ONEDRIVE_INTEGRATION': request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, # Web search settings - "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "WEB_FETCH_MAX_CONTENT_LENGTH": request.app.state.config.WEB_FETCH_MAX_CONTENT_LENGTH, - "WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, - "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - "SEARXNG_LANGUAGE": request.app.state.config.SEARXNG_LANGUAGE, - "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "DDGS_BACKEND": request.app.state.config.DDGS_BACKEND, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "JINA_API_BASE_URL": request.app.state.config.JINA_API_BASE_URL, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, - "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, - "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, - "PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "WEB_LOADER_TIMEOUT": request.app.state.config.WEB_LOADER_TIMEOUT, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "FIRECRAWL_TIMEOUT": request.app.state.config.FIRECRAWL_TIMEOUT, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, - "YANDEX_WEB_SEARCH_URL": request.app.state.config.YANDEX_WEB_SEARCH_URL, - "YANDEX_WEB_SEARCH_API_KEY": request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, - "YANDEX_WEB_SEARCH_CONFIG": request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, - "YOUCOM_API_KEY": request.app.state.config.YOUCOM_API_KEY, + 'web': { + 'ENABLE_WEB_SEARCH': request.app.state.config.ENABLE_WEB_SEARCH, + 'WEB_SEARCH_ENGINE': request.app.state.config.WEB_SEARCH_ENGINE, + 'WEB_SEARCH_TRUST_ENV': request.app.state.config.WEB_SEARCH_TRUST_ENV, + 'WEB_SEARCH_RESULT_COUNT': request.app.state.config.WEB_SEARCH_RESULT_COUNT, + 'WEB_SEARCH_CONCURRENT_REQUESTS': request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + 'WEB_FETCH_MAX_CONTENT_LENGTH': request.app.state.config.WEB_FETCH_MAX_CONTENT_LENGTH, + 'WEB_LOADER_CONCURRENT_REQUESTS': request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, + 'WEB_SEARCH_DOMAIN_FILTER_LIST': request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + 'BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL': request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + 'BYPASS_WEB_SEARCH_WEB_LOADER': request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + 'OLLAMA_CLOUD_WEB_SEARCH_API_KEY': request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, + 'SEARXNG_QUERY_URL': request.app.state.config.SEARXNG_QUERY_URL, + 'SEARXNG_LANGUAGE': request.app.state.config.SEARXNG_LANGUAGE, + 'YACY_QUERY_URL': request.app.state.config.YACY_QUERY_URL, + 'YACY_USERNAME': request.app.state.config.YACY_USERNAME, + 'YACY_PASSWORD': request.app.state.config.YACY_PASSWORD, + 'GOOGLE_PSE_API_KEY': request.app.state.config.GOOGLE_PSE_API_KEY, + 'GOOGLE_PSE_ENGINE_ID': request.app.state.config.GOOGLE_PSE_ENGINE_ID, + 'BRAVE_SEARCH_API_KEY': request.app.state.config.BRAVE_SEARCH_API_KEY, + 'KAGI_SEARCH_API_KEY': request.app.state.config.KAGI_SEARCH_API_KEY, + 'MOJEEK_SEARCH_API_KEY': request.app.state.config.MOJEEK_SEARCH_API_KEY, + 'BOCHA_SEARCH_API_KEY': request.app.state.config.BOCHA_SEARCH_API_KEY, + 'SERPSTACK_API_KEY': request.app.state.config.SERPSTACK_API_KEY, + 'SERPSTACK_HTTPS': request.app.state.config.SERPSTACK_HTTPS, + 'SERPER_API_KEY': request.app.state.config.SERPER_API_KEY, + 'SERPLY_API_KEY': request.app.state.config.SERPLY_API_KEY, + 'DDGS_BACKEND': request.app.state.config.DDGS_BACKEND, + 'TAVILY_API_KEY': request.app.state.config.TAVILY_API_KEY, + 'SEARCHAPI_API_KEY': request.app.state.config.SEARCHAPI_API_KEY, + 'SEARCHAPI_ENGINE': request.app.state.config.SEARCHAPI_ENGINE, + 'SERPAPI_API_KEY': request.app.state.config.SERPAPI_API_KEY, + 'SERPAPI_ENGINE': request.app.state.config.SERPAPI_ENGINE, + 'JINA_API_KEY': request.app.state.config.JINA_API_KEY, + 'JINA_API_BASE_URL': request.app.state.config.JINA_API_BASE_URL, + 'BING_SEARCH_V7_ENDPOINT': request.app.state.config.BING_SEARCH_V7_ENDPOINT, + 'BING_SEARCH_V7_SUBSCRIPTION_KEY': request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + 'EXA_API_KEY': request.app.state.config.EXA_API_KEY, + 'PERPLEXITY_API_KEY': request.app.state.config.PERPLEXITY_API_KEY, + 'PERPLEXITY_MODEL': request.app.state.config.PERPLEXITY_MODEL, + 'PERPLEXITY_SEARCH_CONTEXT_USAGE': request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, + 'PERPLEXITY_SEARCH_API_URL': request.app.state.config.PERPLEXITY_SEARCH_API_URL, + 'SOUGOU_API_SID': request.app.state.config.SOUGOU_API_SID, + 'SOUGOU_API_SK': request.app.state.config.SOUGOU_API_SK, + 'WEB_LOADER_ENGINE': request.app.state.config.WEB_LOADER_ENGINE, + 'WEB_LOADER_TIMEOUT': request.app.state.config.WEB_LOADER_TIMEOUT, + 'ENABLE_WEB_LOADER_SSL_VERIFICATION': request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + 'PLAYWRIGHT_WS_URL': request.app.state.config.PLAYWRIGHT_WS_URL, + 'PLAYWRIGHT_TIMEOUT': request.app.state.config.PLAYWRIGHT_TIMEOUT, + 'FIRECRAWL_API_KEY': request.app.state.config.FIRECRAWL_API_KEY, + 'FIRECRAWL_API_BASE_URL': request.app.state.config.FIRECRAWL_API_BASE_URL, + 'FIRECRAWL_TIMEOUT': request.app.state.config.FIRECRAWL_TIMEOUT, + 'TAVILY_EXTRACT_DEPTH': request.app.state.config.TAVILY_EXTRACT_DEPTH, + 'EXTERNAL_WEB_SEARCH_URL': request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + 'EXTERNAL_WEB_SEARCH_API_KEY': request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + 'EXTERNAL_WEB_LOADER_URL': request.app.state.config.EXTERNAL_WEB_LOADER_URL, + 'EXTERNAL_WEB_LOADER_API_KEY': request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, + 'YOUTUBE_LOADER_LANGUAGE': request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + 'YOUTUBE_LOADER_PROXY_URL': request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + 'YOUTUBE_LOADER_TRANSLATION': request.app.state.YOUTUBE_LOADER_TRANSLATION, + 'YANDEX_WEB_SEARCH_URL': request.app.state.config.YANDEX_WEB_SEARCH_URL, + 'YANDEX_WEB_SEARCH_API_KEY': request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, + 'YANDEX_WEB_SEARCH_CONFIG': request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, + 'YOUCOM_API_KEY': request.app.state.config.YOUCOM_API_KEY, }, } @@ -745,21 +717,13 @@ class ConfigForm(BaseModel): web: Optional[WebConfig] = None -@router.post("/config/update") -async def update_rag_config( - request: Request, form_data: ConfigForm, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_rag_config(request: Request, form_data: ConfigForm, user=Depends(get_admin_user)): # RAG settings request.app.state.config.RAG_TEMPLATE = ( - form_data.RAG_TEMPLATE - if form_data.RAG_TEMPLATE is not None - else request.app.state.config.RAG_TEMPLATE - ) - request.app.state.config.TOP_K = ( - form_data.TOP_K - if form_data.TOP_K is not None - else request.app.state.config.TOP_K + form_data.RAG_TEMPLATE if form_data.RAG_TEMPLATE is not None else request.app.state.config.RAG_TEMPLATE ) + request.app.state.config.TOP_K = form_data.TOP_K if form_data.TOP_K is not None else request.app.state.config.TOP_K request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( form_data.BYPASS_EMBEDDING_AND_RETRIEVAL if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None @@ -784,9 +748,7 @@ async def update_rag_config( ) request.app.state.config.TOP_K_RERANKER = ( - form_data.TOP_K_RERANKER - if form_data.TOP_K_RERANKER is not None - else request.app.state.config.TOP_K_RERANKER + form_data.TOP_K_RERANKER if form_data.TOP_K_RERANKER is not None else request.app.state.config.TOP_K_RERANKER ) request.app.state.config.RELEVANCE_THRESHOLD = ( form_data.RELEVANCE_THRESHOLD @@ -811,9 +773,7 @@ async def update_rag_config( else request.app.state.config.PDF_EXTRACT_IMAGES ) request.app.state.config.PDF_LOADER_MODE = ( - form_data.PDF_LOADER_MODE - if form_data.PDF_LOADER_MODE is not None - else request.app.state.config.PDF_LOADER_MODE + form_data.PDF_LOADER_MODE if form_data.PDF_LOADER_MODE is not None else request.app.state.config.PDF_LOADER_MODE ) request.app.state.config.DATALAB_MARKER_API_KEY = ( form_data.DATALAB_MARKER_API_KEY @@ -881,9 +841,7 @@ async def update_rag_config( else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY ) request.app.state.config.TIKA_SERVER_URL = ( - form_data.TIKA_SERVER_URL - if form_data.TIKA_SERVER_URL is not None - else request.app.state.config.TIKA_SERVER_URL + form_data.TIKA_SERVER_URL if form_data.TIKA_SERVER_URL is not None else request.app.state.config.TIKA_SERVER_URL ) request.app.state.config.DOCLING_SERVER_URL = ( form_data.DOCLING_SERVER_URL @@ -891,14 +849,10 @@ async def update_rag_config( else request.app.state.config.DOCLING_SERVER_URL ) request.app.state.config.DOCLING_API_KEY = ( - form_data.DOCLING_API_KEY - if form_data.DOCLING_API_KEY is not None - else request.app.state.config.DOCLING_API_KEY + form_data.DOCLING_API_KEY if form_data.DOCLING_API_KEY is not None else request.app.state.config.DOCLING_API_KEY ) request.app.state.config.DOCLING_PARAMS = ( - form_data.DOCLING_PARAMS - if form_data.DOCLING_PARAMS is not None - else request.app.state.config.DOCLING_PARAMS + form_data.DOCLING_PARAMS if form_data.DOCLING_PARAMS is not None else request.app.state.config.DOCLING_PARAMS ) request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( form_data.DOCUMENT_INTELLIGENCE_ENDPOINT @@ -929,19 +883,13 @@ async def update_rag_config( # MinerU settings request.app.state.config.MINERU_API_MODE = ( - form_data.MINERU_API_MODE - if form_data.MINERU_API_MODE is not None - else request.app.state.config.MINERU_API_MODE + form_data.MINERU_API_MODE if form_data.MINERU_API_MODE is not None else request.app.state.config.MINERU_API_MODE ) request.app.state.config.MINERU_API_URL = ( - form_data.MINERU_API_URL - if form_data.MINERU_API_URL is not None - else request.app.state.config.MINERU_API_URL + form_data.MINERU_API_URL if form_data.MINERU_API_URL is not None else request.app.state.config.MINERU_API_URL ) request.app.state.config.MINERU_API_KEY = ( - form_data.MINERU_API_KEY - if form_data.MINERU_API_KEY is not None - else request.app.state.config.MINERU_API_KEY + form_data.MINERU_API_KEY if form_data.MINERU_API_KEY is not None else request.app.state.config.MINERU_API_KEY ) request.app.state.config.MINERU_API_TIMEOUT = ( form_data.MINERU_API_TIMEOUT @@ -949,20 +897,18 @@ async def update_rag_config( else request.app.state.config.MINERU_API_TIMEOUT ) request.app.state.config.MINERU_PARAMS = ( - form_data.MINERU_PARAMS - if form_data.MINERU_PARAMS is not None - else request.app.state.config.MINERU_PARAMS + form_data.MINERU_PARAMS if form_data.MINERU_PARAMS is not None else request.app.state.config.MINERU_PARAMS ) # Reranking settings - if request.app.state.config.RAG_RERANKING_ENGINE == "": + if request.app.state.config.RAG_RERANKING_ENGINE == '': # Unloading the internal reranker and clear VRAM memory request.app.state.rf = None request.app.state.RERANKING_FUNCTION = None import gc gc.collect() - if DEVICE_TYPE == "cuda": + if DEVICE_TYPE == 'cuda': import torch if torch.cuda.is_available(): @@ -992,7 +938,7 @@ async def update_rag_config( ) log.info( - f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" + f'Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}' ) try: request.app.state.config.RAG_RERANKING_MODEL = ( @@ -1020,10 +966,10 @@ async def update_rag_config( request.app.state.rf, ) except Exception as e: - log.error(f"Error loading reranking model: {e}") + log.error(f'Error loading reranking model: {e}') request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False except Exception as e: - log.exception(f"Problem updating reranking model: {e}") + log.exception(f'Problem updating reranking model: {e}') raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=ERROR_MESSAGES.DEFAULT(e), @@ -1031,9 +977,7 @@ async def update_rag_config( # Chunking settings request.app.state.config.TEXT_SPLITTER = ( - form_data.TEXT_SPLITTER - if form_data.TEXT_SPLITTER is not None - else request.app.state.config.TEXT_SPLITTER + form_data.TEXT_SPLITTER if form_data.TEXT_SPLITTER is not None else request.app.state.config.TEXT_SPLITTER ) request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ( form_data.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER @@ -1041,9 +985,7 @@ async def update_rag_config( else request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER ) request.app.state.config.CHUNK_SIZE = ( - form_data.CHUNK_SIZE - if form_data.CHUNK_SIZE is not None - else request.app.state.config.CHUNK_SIZE + form_data.CHUNK_SIZE if form_data.CHUNK_SIZE is not None else request.app.state.config.CHUNK_SIZE ) request.app.state.config.CHUNK_MIN_SIZE_TARGET = ( form_data.CHUNK_MIN_SIZE_TARGET @@ -1051,33 +993,23 @@ async def update_rag_config( else request.app.state.config.CHUNK_MIN_SIZE_TARGET ) request.app.state.config.CHUNK_OVERLAP = ( - form_data.CHUNK_OVERLAP - if form_data.CHUNK_OVERLAP is not None - else request.app.state.config.CHUNK_OVERLAP + form_data.CHUNK_OVERLAP if form_data.CHUNK_OVERLAP is not None else request.app.state.config.CHUNK_OVERLAP ) # File upload settings # Empty string means "clear to None" (unlimited/no compression), # None means "don't change", int means "set to this value" if form_data.FILE_MAX_SIZE is not None: - request.app.state.config.FILE_MAX_SIZE = ( - None if form_data.FILE_MAX_SIZE == "" else form_data.FILE_MAX_SIZE - ) + request.app.state.config.FILE_MAX_SIZE = None if form_data.FILE_MAX_SIZE == '' else form_data.FILE_MAX_SIZE if form_data.FILE_MAX_COUNT is not None: - request.app.state.config.FILE_MAX_COUNT = ( - None if form_data.FILE_MAX_COUNT == "" else form_data.FILE_MAX_COUNT - ) + request.app.state.config.FILE_MAX_COUNT = None if form_data.FILE_MAX_COUNT == '' else form_data.FILE_MAX_COUNT if form_data.FILE_IMAGE_COMPRESSION_WIDTH is not None: request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = ( - None - if form_data.FILE_IMAGE_COMPRESSION_WIDTH == "" - else form_data.FILE_IMAGE_COMPRESSION_WIDTH + None if form_data.FILE_IMAGE_COMPRESSION_WIDTH == '' else form_data.FILE_IMAGE_COMPRESSION_WIDTH ) if form_data.FILE_IMAGE_COMPRESSION_HEIGHT is not None: request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = ( - None - if form_data.FILE_IMAGE_COMPRESSION_HEIGHT == "" - else form_data.FILE_IMAGE_COMPRESSION_HEIGHT + None if form_data.FILE_IMAGE_COMPRESSION_HEIGHT == '' else form_data.FILE_IMAGE_COMPRESSION_HEIGHT ) request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( @@ -1102,52 +1034,28 @@ async def update_rag_config( # Web search settings request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE - request.app.state.config.WEB_SEARCH_TRUST_ENV = ( - form_data.web.WEB_SEARCH_TRUST_ENV - ) - request.app.state.config.WEB_SEARCH_RESULT_COUNT = ( - form_data.web.WEB_SEARCH_RESULT_COUNT - ) - request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( - form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS - ) - request.app.state.config.WEB_FETCH_MAX_CONTENT_LENGTH = ( - form_data.web.WEB_FETCH_MAX_CONTENT_LENGTH - ) - request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = ( - form_data.web.WEB_LOADER_CONCURRENT_REQUESTS - ) - request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( - form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST - ) + request.app.state.config.WEB_SEARCH_TRUST_ENV = form_data.web.WEB_SEARCH_TRUST_ENV + request.app.state.config.WEB_SEARCH_RESULT_COUNT = form_data.web.WEB_SEARCH_RESULT_COUNT + request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS + request.app.state.config.WEB_FETCH_MAX_CONTENT_LENGTH = form_data.web.WEB_FETCH_MAX_CONTENT_LENGTH + request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = form_data.web.WEB_LOADER_CONCURRENT_REQUESTS + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL ) - request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( - form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER - ) - request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = ( - form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY - ) + request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER + request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL request.app.state.config.SEARXNG_LANGUAGE = form_data.web.SEARXNG_LANGUAGE request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY - request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( - form_data.web.GOOGLE_PSE_ENGINE_ID - ) - request.app.state.config.BRAVE_SEARCH_API_KEY = ( - form_data.web.BRAVE_SEARCH_API_KEY - ) + request.app.state.config.GOOGLE_PSE_ENGINE_ID = form_data.web.GOOGLE_PSE_ENGINE_ID + request.app.state.config.BRAVE_SEARCH_API_KEY = form_data.web.BRAVE_SEARCH_API_KEY request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY - request.app.state.config.MOJEEK_SEARCH_API_KEY = ( - form_data.web.MOJEEK_SEARCH_API_KEY - ) - request.app.state.config.BOCHA_SEARCH_API_KEY = ( - form_data.web.BOCHA_SEARCH_API_KEY - ) + request.app.state.config.MOJEEK_SEARCH_API_KEY = form_data.web.MOJEEK_SEARCH_API_KEY + request.app.state.config.BOCHA_SEARCH_API_KEY = form_data.web.BOCHA_SEARCH_API_KEY request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY @@ -1160,21 +1068,13 @@ async def update_rag_config( request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY request.app.state.config.JINA_API_BASE_URL = form_data.web.JINA_API_BASE_URL - request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( - form_data.web.BING_SEARCH_V7_ENDPOINT - ) - request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( - form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY - ) + request.app.state.config.BING_SEARCH_V7_ENDPOINT = form_data.web.BING_SEARCH_V7_ENDPOINT + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL - request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( - form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE - ) - request.app.state.config.PERPLEXITY_SEARCH_API_URL = ( - form_data.web.PERPLEXITY_SEARCH_API_URL - ) + request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE + request.app.state.config.PERPLEXITY_SEARCH_API_URL = form_data.web.PERPLEXITY_SEARCH_API_URL request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK @@ -1182,179 +1082,153 @@ async def update_rag_config( request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE request.app.state.config.WEB_LOADER_TIMEOUT = form_data.web.WEB_LOADER_TIMEOUT - request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION - ) + request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY - request.app.state.config.FIRECRAWL_API_BASE_URL = ( - form_data.web.FIRECRAWL_API_BASE_URL - ) + request.app.state.config.FIRECRAWL_API_BASE_URL = form_data.web.FIRECRAWL_API_BASE_URL request.app.state.config.FIRECRAWL_TIMEOUT = form_data.web.FIRECRAWL_TIMEOUT - request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( - form_data.web.EXTERNAL_WEB_SEARCH_URL - ) - request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( - form_data.web.EXTERNAL_WEB_SEARCH_API_KEY - ) - request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( - form_data.web.EXTERNAL_WEB_LOADER_URL - ) - request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( - form_data.web.EXTERNAL_WEB_LOADER_API_KEY - ) - request.app.state.config.TAVILY_EXTRACT_DEPTH = ( - form_data.web.TAVILY_EXTRACT_DEPTH - ) - request.app.state.config.YOUTUBE_LOADER_LANGUAGE = ( - form_data.web.YOUTUBE_LOADER_LANGUAGE - ) - request.app.state.config.YOUTUBE_LOADER_PROXY_URL = ( - form_data.web.YOUTUBE_LOADER_PROXY_URL - ) - request.app.state.YOUTUBE_LOADER_TRANSLATION = ( - form_data.web.YOUTUBE_LOADER_TRANSLATION - ) - request.app.state.config.YANDEX_WEB_SEARCH_URL = ( - form_data.web.YANDEX_WEB_SEARCH_URL - ) - request.app.state.config.YANDEX_WEB_SEARCH_API_KEY = ( - form_data.web.YANDEX_WEB_SEARCH_API_KEY - ) - request.app.state.config.YANDEX_WEB_SEARCH_CONFIG = ( - form_data.web.YANDEX_WEB_SEARCH_CONFIG - ) + request.app.state.config.EXTERNAL_WEB_SEARCH_URL = form_data.web.EXTERNAL_WEB_SEARCH_URL + request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = form_data.web.EXTERNAL_WEB_SEARCH_API_KEY + request.app.state.config.EXTERNAL_WEB_LOADER_URL = form_data.web.EXTERNAL_WEB_LOADER_URL + request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = form_data.web.EXTERNAL_WEB_LOADER_API_KEY + request.app.state.config.TAVILY_EXTRACT_DEPTH = form_data.web.TAVILY_EXTRACT_DEPTH + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.web.YOUTUBE_LOADER_LANGUAGE + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.web.YOUTUBE_LOADER_PROXY_URL + request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.web.YOUTUBE_LOADER_TRANSLATION + request.app.state.config.YANDEX_WEB_SEARCH_URL = form_data.web.YANDEX_WEB_SEARCH_URL + request.app.state.config.YANDEX_WEB_SEARCH_API_KEY = form_data.web.YANDEX_WEB_SEARCH_API_KEY + request.app.state.config.YANDEX_WEB_SEARCH_CONFIG = form_data.web.YANDEX_WEB_SEARCH_CONFIG request.app.state.config.YOUCOM_API_KEY = form_data.web.YOUCOM_API_KEY return { - "status": True, + 'status': True, # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + 'RAG_TEMPLATE': request.app.state.config.RAG_TEMPLATE, + 'TOP_K': request.app.state.config.TOP_K, + 'BYPASS_EMBEDDING_AND_RETRIEVAL': request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, + 'RAG_FULL_CONTEXT': request.app.state.config.RAG_FULL_CONTEXT, # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, - "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, + 'ENABLE_RAG_HYBRID_SEARCH': request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + 'TOP_K_RERANKER': request.app.state.config.TOP_K_RERANKER, + 'RELEVANCE_THRESHOLD': request.app.state.config.RELEVANCE_THRESHOLD, + 'HYBRID_BM25_WEIGHT': request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "PDF_LOADER_MODE": request.app.state.config.PDF_LOADER_MODE, - "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, - "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, - "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, - "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, - "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, - "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, - "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_API_KEY": request.app.state.config.DOCLING_API_KEY, - "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "DOCUMENT_INTELLIGENCE_MODEL": request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL, - "MISTRAL_OCR_API_BASE_URL": request.app.state.config.MISTRAL_OCR_API_BASE_URL, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + 'CONTENT_EXTRACTION_ENGINE': request.app.state.config.CONTENT_EXTRACTION_ENGINE, + 'PDF_EXTRACT_IMAGES': request.app.state.config.PDF_EXTRACT_IMAGES, + 'PDF_LOADER_MODE': request.app.state.config.PDF_LOADER_MODE, + 'DATALAB_MARKER_API_KEY': request.app.state.config.DATALAB_MARKER_API_KEY, + 'DATALAB_MARKER_API_BASE_URL': request.app.state.config.DATALAB_MARKER_API_BASE_URL, + 'DATALAB_MARKER_ADDITIONAL_CONFIG': request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + 'DATALAB_MARKER_SKIP_CACHE': request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + 'DATALAB_MARKER_FORCE_OCR': request.app.state.config.DATALAB_MARKER_FORCE_OCR, + 'DATALAB_MARKER_PAGINATE': request.app.state.config.DATALAB_MARKER_PAGINATE, + 'DATALAB_MARKER_STRIP_EXISTING_OCR': request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + 'DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION': request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + 'DATALAB_MARKER_USE_LLM': request.app.state.config.DATALAB_MARKER_USE_LLM, + 'DATALAB_MARKER_OUTPUT_FORMAT': request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, + 'EXTERNAL_DOCUMENT_LOADER_URL': request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + 'EXTERNAL_DOCUMENT_LOADER_API_KEY': request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + 'TIKA_SERVER_URL': request.app.state.config.TIKA_SERVER_URL, + 'DOCLING_SERVER_URL': request.app.state.config.DOCLING_SERVER_URL, + 'DOCLING_API_KEY': request.app.state.config.DOCLING_API_KEY, + 'DOCLING_PARAMS': request.app.state.config.DOCLING_PARAMS, + 'DOCUMENT_INTELLIGENCE_ENDPOINT': request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + 'DOCUMENT_INTELLIGENCE_KEY': request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + 'DOCUMENT_INTELLIGENCE_MODEL': request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL, + 'MISTRAL_OCR_API_BASE_URL': request.app.state.config.MISTRAL_OCR_API_BASE_URL, + 'MISTRAL_OCR_API_KEY': request.app.state.config.MISTRAL_OCR_API_KEY, # MinerU settings - "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE, - "MINERU_API_URL": request.app.state.config.MINERU_API_URL, - "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY, - "MINERU_API_TIMEOUT": request.app.state.config.MINERU_API_TIMEOUT, - "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS, + 'MINERU_API_MODE': request.app.state.config.MINERU_API_MODE, + 'MINERU_API_URL': request.app.state.config.MINERU_API_URL, + 'MINERU_API_KEY': request.app.state.config.MINERU_API_KEY, + 'MINERU_API_TIMEOUT': request.app.state.config.MINERU_API_TIMEOUT, + 'MINERU_PARAMS': request.app.state.config.MINERU_PARAMS, # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - "RAG_EXTERNAL_RERANKER_TIMEOUT": request.app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT, + 'RAG_RERANKING_MODEL': request.app.state.config.RAG_RERANKING_MODEL, + 'RAG_RERANKING_ENGINE': request.app.state.config.RAG_RERANKING_ENGINE, + 'RAG_EXTERNAL_RERANKER_URL': request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + 'RAG_EXTERNAL_RERANKER_API_KEY': request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + 'RAG_EXTERNAL_RERANKER_TIMEOUT': request.app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT, # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_MIN_SIZE_TARGET": request.app.state.config.CHUNK_MIN_SIZE_TARGET, - "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER": request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + 'TEXT_SPLITTER': request.app.state.config.TEXT_SPLITTER, + 'CHUNK_SIZE': request.app.state.config.CHUNK_SIZE, + 'CHUNK_MIN_SIZE_TARGET': request.app.state.config.CHUNK_MIN_SIZE_TARGET, + 'ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER': request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, + 'CHUNK_OVERLAP': request.app.state.config.CHUNK_OVERLAP, # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + 'FILE_MAX_SIZE': request.app.state.config.FILE_MAX_SIZE, + 'FILE_MAX_COUNT': request.app.state.config.FILE_MAX_COUNT, + 'FILE_IMAGE_COMPRESSION_WIDTH': request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + 'FILE_IMAGE_COMPRESSION_HEIGHT': request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + 'ALLOWED_FILE_EXTENSIONS': request.app.state.config.ALLOWED_FILE_EXTENSIONS, # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + 'ENABLE_GOOGLE_DRIVE_INTEGRATION': request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + 'ENABLE_ONEDRIVE_INTEGRATION': request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, # Web search settings - "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "FETCH_URL_MAX_CONTENT_LENGTH": request.app.state.config.FETCH_URL_MAX_CONTENT_LENGTH, - "WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, - "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - "SEARXNG_LANGUAGE": request.app.state.config.SEARXNG_LANGUAGE, - "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "JINA_API_BASE_URL": request.app.state.config.JINA_API_BASE_URL, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, - "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, - "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, - "PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "WEB_LOADER_TIMEOUT": request.app.state.config.WEB_LOADER_TIMEOUT, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "FIRECRAWL_TIMEOUT": request.app.state.config.FIRECRAWL_TIMEOUT, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, - "YANDEX_WEB_SEARCH_URL": request.app.state.config.YANDEX_WEB_SEARCH_URL, - "YANDEX_WEB_SEARCH_API_KEY": request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, - "YANDEX_WEB_SEARCH_CONFIG": request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, - "YOUCOM_API_KEY": request.app.state.config.YOUCOM_API_KEY, + 'web': { + 'ENABLE_WEB_SEARCH': request.app.state.config.ENABLE_WEB_SEARCH, + 'WEB_SEARCH_ENGINE': request.app.state.config.WEB_SEARCH_ENGINE, + 'WEB_SEARCH_TRUST_ENV': request.app.state.config.WEB_SEARCH_TRUST_ENV, + 'WEB_SEARCH_RESULT_COUNT': request.app.state.config.WEB_SEARCH_RESULT_COUNT, + 'WEB_SEARCH_CONCURRENT_REQUESTS': request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + 'FETCH_URL_MAX_CONTENT_LENGTH': request.app.state.config.FETCH_URL_MAX_CONTENT_LENGTH, + 'WEB_LOADER_CONCURRENT_REQUESTS': request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, + 'WEB_SEARCH_DOMAIN_FILTER_LIST': request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + 'BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL': request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + 'BYPASS_WEB_SEARCH_WEB_LOADER': request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + 'OLLAMA_CLOUD_WEB_SEARCH_API_KEY': request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, + 'SEARXNG_QUERY_URL': request.app.state.config.SEARXNG_QUERY_URL, + 'SEARXNG_LANGUAGE': request.app.state.config.SEARXNG_LANGUAGE, + 'YACY_QUERY_URL': request.app.state.config.YACY_QUERY_URL, + 'YACY_USERNAME': request.app.state.config.YACY_USERNAME, + 'YACY_PASSWORD': request.app.state.config.YACY_PASSWORD, + 'GOOGLE_PSE_API_KEY': request.app.state.config.GOOGLE_PSE_API_KEY, + 'GOOGLE_PSE_ENGINE_ID': request.app.state.config.GOOGLE_PSE_ENGINE_ID, + 'BRAVE_SEARCH_API_KEY': request.app.state.config.BRAVE_SEARCH_API_KEY, + 'KAGI_SEARCH_API_KEY': request.app.state.config.KAGI_SEARCH_API_KEY, + 'MOJEEK_SEARCH_API_KEY': request.app.state.config.MOJEEK_SEARCH_API_KEY, + 'BOCHA_SEARCH_API_KEY': request.app.state.config.BOCHA_SEARCH_API_KEY, + 'SERPSTACK_API_KEY': request.app.state.config.SERPSTACK_API_KEY, + 'SERPSTACK_HTTPS': request.app.state.config.SERPSTACK_HTTPS, + 'SERPER_API_KEY': request.app.state.config.SERPER_API_KEY, + 'SERPLY_API_KEY': request.app.state.config.SERPLY_API_KEY, + 'TAVILY_API_KEY': request.app.state.config.TAVILY_API_KEY, + 'SEARCHAPI_API_KEY': request.app.state.config.SEARCHAPI_API_KEY, + 'SEARCHAPI_ENGINE': request.app.state.config.SEARCHAPI_ENGINE, + 'SERPAPI_API_KEY': request.app.state.config.SERPAPI_API_KEY, + 'SERPAPI_ENGINE': request.app.state.config.SERPAPI_ENGINE, + 'JINA_API_KEY': request.app.state.config.JINA_API_KEY, + 'JINA_API_BASE_URL': request.app.state.config.JINA_API_BASE_URL, + 'BING_SEARCH_V7_ENDPOINT': request.app.state.config.BING_SEARCH_V7_ENDPOINT, + 'BING_SEARCH_V7_SUBSCRIPTION_KEY': request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + 'EXA_API_KEY': request.app.state.config.EXA_API_KEY, + 'PERPLEXITY_API_KEY': request.app.state.config.PERPLEXITY_API_KEY, + 'PERPLEXITY_MODEL': request.app.state.config.PERPLEXITY_MODEL, + 'PERPLEXITY_SEARCH_CONTEXT_USAGE': request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, + 'PERPLEXITY_SEARCH_API_URL': request.app.state.config.PERPLEXITY_SEARCH_API_URL, + 'SOUGOU_API_SID': request.app.state.config.SOUGOU_API_SID, + 'SOUGOU_API_SK': request.app.state.config.SOUGOU_API_SK, + 'WEB_LOADER_ENGINE': request.app.state.config.WEB_LOADER_ENGINE, + 'WEB_LOADER_TIMEOUT': request.app.state.config.WEB_LOADER_TIMEOUT, + 'ENABLE_WEB_LOADER_SSL_VERIFICATION': request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + 'PLAYWRIGHT_WS_URL': request.app.state.config.PLAYWRIGHT_WS_URL, + 'PLAYWRIGHT_TIMEOUT': request.app.state.config.PLAYWRIGHT_TIMEOUT, + 'FIRECRAWL_API_KEY': request.app.state.config.FIRECRAWL_API_KEY, + 'FIRECRAWL_API_BASE_URL': request.app.state.config.FIRECRAWL_API_BASE_URL, + 'FIRECRAWL_TIMEOUT': request.app.state.config.FIRECRAWL_TIMEOUT, + 'TAVILY_EXTRACT_DEPTH': request.app.state.config.TAVILY_EXTRACT_DEPTH, + 'EXTERNAL_WEB_SEARCH_URL': request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + 'EXTERNAL_WEB_SEARCH_API_KEY': request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + 'EXTERNAL_WEB_LOADER_URL': request.app.state.config.EXTERNAL_WEB_LOADER_URL, + 'EXTERNAL_WEB_LOADER_API_KEY': request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, + 'YOUTUBE_LOADER_LANGUAGE': request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + 'YOUTUBE_LOADER_PROXY_URL': request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + 'YOUTUBE_LOADER_TRANSLATION': request.app.state.YOUTUBE_LOADER_TRANSLATION, + 'YANDEX_WEB_SEARCH_URL': request.app.state.config.YANDEX_WEB_SEARCH_URL, + 'YANDEX_WEB_SEARCH_API_KEY': request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, + 'YANDEX_WEB_SEARCH_CONFIG': request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, + 'YOUCOM_API_KEY': request.app.state.config.YOUCOM_API_KEY, }, } @@ -1367,11 +1241,11 @@ async def update_rag_config( def can_merge_chunks(a: Document, b: Document) -> bool: - if a.metadata.get("source") != b.metadata.get("source"): + if a.metadata.get('source') != b.metadata.get('source'): return False - a_file_id = a.metadata.get("file_id") - b_file_id = b.metadata.get("file_id") + a_file_id = a.metadata.get('file_id') + b_file_id = b.metadata.get('file_id') if a_file_id is not None and b_file_id is not None: return a_file_id == b_file_id @@ -1397,16 +1271,14 @@ def merge_docs_to_target_size( return chunks measure_chunk_size = len - if request.app.state.config.TEXT_SPLITTER == "token": - encoding = tiktoken.get_encoding( - str(request.app.state.config.TIKTOKEN_ENCODING_NAME) - ) + if request.app.state.config.TEXT_SPLITTER == 'token': + encoding = tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) measure_chunk_size = lambda text: len(encoding.encode(text)) processed_chunks: list[Document] = [] current_chunk: Document | None = None - current_content: str = "" + current_content: str = '' for next_chunk in chunks: if current_chunk is None: @@ -1414,7 +1286,7 @@ def merge_docs_to_target_size( current_content = next_chunk.page_content continue # First chunk initialization - proposed_content = f"{current_content}\n\n{next_chunk.page_content}" + proposed_content = f'{current_content}\n\n{next_chunk.page_content}' can_merge = ( can_merge_chunks(current_chunk, next_chunk) @@ -1460,26 +1332,24 @@ def save_docs_to_vector_db( # Trying to select relevant metadata identifying the document. for doc in docs: - metadata = getattr(doc, "metadata", {}) - doc_name = metadata.get("name", "") + metadata = getattr(doc, 'metadata', {}) + doc_name = metadata.get('name', '') if not doc_name: - doc_name = metadata.get("title", "") + doc_name = metadata.get('title', '') if not doc_name: - doc_name = metadata.get("source", "") + doc_name = metadata.get('source', '') if doc_name: docs_info.add(doc_name) - return ", ".join(docs_info) + return ', '.join(docs_info) - log.debug( - f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" - ) + log.debug(f'save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}') # Check if entries with the same hash (metadata.hash) already exist - if metadata and "hash" in metadata: + if metadata and 'hash' in metadata: result = VECTOR_DB_CLIENT.query( collection_name=collection_name, - filter={"hash": metadata["hash"]}, + filter={'hash': metadata['hash']}, ) if result is not None and result.ids and len(result.ids) > 0: @@ -1490,24 +1360,24 @@ def save_docs_to_vector_db( # If different file_id, this is a duplicate - block it existing_file_id = None if result.metadatas and result.metadatas[0]: - existing_file_id = result.metadatas[0][0].get("file_id") + existing_file_id = result.metadatas[0][0].get('file_id') - if existing_file_id != metadata.get("file_id"): - log.info(f"Document with hash {metadata['hash']} already exists") + if existing_file_id != metadata.get('file_id'): + log.info(f'Document with hash {metadata["hash"]} already exists') raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: if request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: - log.info("Using markdown header text splitter") + log.info('Using markdown header text splitter') # Define headers to split on - covering most common markdown header levels markdown_splitter = MarkdownHeaderTextSplitter( headers_to_split_on=[ - ("#", "Header 1"), - ("##", "Header 2"), - ("###", "Header 3"), - ("####", "Header 4"), - ("#####", "Header 5"), - ("######", "Header 6"), + ('#', 'Header 1'), + ('##', 'Header 2'), + ('###', 'Header 3'), + ('####', 'Header 4'), + ('#####', 'Header 5'), + ('######', 'Header 6'), ], strip_headers=False, # Keep headers in content for context ) @@ -1520,9 +1390,7 @@ def save_docs_to_vector_db( page_content=split_chunk.page_content, metadata={**doc.metadata}, ) - for split_chunk in markdown_splitter.split_text( - doc.page_content - ) + for split_chunk in markdown_splitter.split_text(doc.page_content) ] ) @@ -1530,17 +1398,15 @@ def save_docs_to_vector_db( if request.app.state.config.CHUNK_MIN_SIZE_TARGET > 0: docs = merge_docs_to_target_size(request, docs) - if request.app.state.config.TEXT_SPLITTER in ["", "character"]: + if request.app.state.config.TEXT_SPLITTER in ['', 'character']: text_splitter = RecursiveCharacterTextSplitter( chunk_size=request.app.state.config.CHUNK_SIZE, chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) docs = text_splitter.split_documents(docs) - elif request.app.state.config.TEXT_SPLITTER == "token": - log.info( - f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" - ) + elif request.app.state.config.TEXT_SPLITTER == 'token': + log.info(f'Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}') tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( @@ -1551,7 +1417,7 @@ def save_docs_to_vector_db( ) docs = text_splitter.split_documents(docs) else: - raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) + raise ValueError(ERROR_MESSAGES.DEFAULT('Invalid text splitter')) if len(docs) == 0: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) @@ -1561,9 +1427,9 @@ def save_docs_to_vector_db( { **doc.metadata, **(metadata if metadata else {}), - "embedding_config": { - "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "model": request.app.state.config.RAG_EMBEDDING_MODEL, + 'embedding_config': { + 'engine': request.app.state.config.RAG_EMBEDDING_ENGINE, + 'model': request.app.state.config.RAG_EMBEDDING_MODEL, }, } for doc in docs @@ -1571,44 +1437,42 @@ def save_docs_to_vector_db( try: if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"collection {collection_name} already exists") + log.info(f'collection {collection_name} already exists') if overwrite: VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - log.info(f"deleting existing collection {collection_name}") + log.info(f'deleting existing collection {collection_name}') elif add is False: - log.info( - f"collection {collection_name} already exists, overwrite is False and add is False" - ) + log.info(f'collection {collection_name} already exists, overwrite is False and add is False') return True - log.info(f"generating embeddings for {collection_name}") + log.info(f'generating embeddings for {collection_name}') embedding_function = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.ef, ( request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( request.app.state.config.RAG_OLLAMA_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL ) ), ( request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'openai' else ( request.app.state.config.RAG_OLLAMA_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'ollama' else request.app.state.config.RAG_AZURE_OPENAI_API_KEY ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, azure_api_version=( request.app.state.config.RAG_AZURE_OPENAI_API_VERSION - if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + if request.app.state.config.RAG_EMBEDDING_ENGINE == 'azure_openai' else None ), enable_async=request.app.state.config.ENABLE_ASYNC_EMBEDDING, @@ -1621,32 +1485,32 @@ def save_docs_to_vector_db( future = asyncio.run_coroutine_threadsafe( embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)), + list(map(lambda x: x.replace('\n', ' '), texts)), prefix=RAG_EMBEDDING_CONTENT_PREFIX, user=user, ), request.app.state.main_loop, ) embeddings = future.result(timeout=embedding_timeout) - log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items") + log.info(f'embeddings generated {len(embeddings)} for {len(texts)} items') items = [ { - "id": str(uuid.uuid4()), - "text": text, - "vector": embeddings[idx], - "metadata": metadatas[idx], + 'id': str(uuid.uuid4()), + 'text': text, + 'vector': embeddings[idx], + 'metadata': metadatas[idx], } for idx, text in enumerate(texts) ] - log.info(f"adding to collection {collection_name}") + log.info(f'adding to collection {collection_name}') VECTOR_DB_CLIENT.insert( collection_name=collection_name, items=items, ) - log.info(f"added {len(items)} items to collection {collection_name}") + log.info(f'added {len(items)} items to collection {collection_name}') return True except Exception as e: log.exception(e) @@ -1659,7 +1523,7 @@ class ProcessFileForm(BaseModel): collection_name: Optional[str] = None -@router.post("/process/file") +@router.post('/process/file') def process_file( request: Request, form_data: ProcessFileForm, @@ -1672,18 +1536,17 @@ def process_file( Note: granular session management is used to prevent connection pool exhaustion. The session is committed before external API calls, and updates use a fresh session. """ - if user.role == "admin": + if user.role == 'admin': file = Files.get_file_by_id(form_data.file_id, db=db) else: file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db) if file: try: - collection_name = form_data.collection_name if collection_name is None: - collection_name = f"file-{file.id}" + collection_name = f'file-{file.id}' if form_data.content: # Update the content in the file @@ -1691,22 +1554,20 @@ def process_file( try: # /files/{file_id}/data/content/update - VECTOR_DB_CLIENT.delete_collection( - collection_name=f"file-{file.id}" - ) + VECTOR_DB_CLIENT.delete_collection(collection_name=f'file-{file.id}') except Exception: # Audio file upload pipeline pass docs = [ Document( - page_content=form_data.content.replace("
", "\n"), + page_content=form_data.content.replace('
', '\n'), metadata={ **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + 'name': file.filename, + 'created_by': file.user_id, + 'file_id': file.id, + 'source': file.filename, }, ) ] @@ -1716,9 +1577,7 @@ def process_file( # Check if the file has already been processed and save the content # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update - result = VECTOR_DB_CLIENT.query( - collection_name=f"file-{file.id}", filter={"file_id": file.id} - ) + result = VECTOR_DB_CLIENT.query(collection_name=f'file-{file.id}', filter={'file_id': file.id}) if result is not None and len(result.ids[0]) > 0: docs = [ @@ -1731,18 +1590,18 @@ def process_file( else: docs = [ Document( - page_content=file.data.get("content", ""), + page_content=file.data.get('content', ''), metadata={ **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + 'name': file.filename, + 'created_by': file.user_id, + 'file_id': file.id, + 'source': file.filename, }, ) ] - text_content = file.data.get("content", "") + text_content = file.data.get('content', '') else: # Process the file and save the content # Usage: /files/ @@ -1782,19 +1641,17 @@ def process_file( MINERU_API_TIMEOUT=request.app.state.config.MINERU_API_TIMEOUT, MINERU_PARAMS=request.app.state.config.MINERU_PARAMS, ) - docs = loader.load( - file.filename, file.meta.get("content_type"), file_path - ) + docs = loader.load(file.filename, file.meta.get('content_type'), file_path) docs = [ Document( page_content=doc.page_content, metadata={ **filter_metadata(doc.metadata), - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + 'name': file.filename, + 'created_by': file.user_id, + 'file_id': file.id, + 'source': file.filename, }, ) for doc in docs @@ -1802,34 +1659,34 @@ def process_file( else: docs = [ Document( - page_content=file.data.get("content", ""), + page_content=file.data.get('content', ''), metadata={ **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + 'name': file.filename, + 'created_by': file.user_id, + 'file_id': file.id, + 'source': file.filename, }, ) ] - text_content = " ".join([doc.page_content for doc in docs]) + text_content = ' '.join([doc.page_content for doc in docs]) - log.debug(f"text_content: {text_content}") + log.debug(f'text_content: {text_content}') Files.update_file_data_by_id( file.id, - {"content": text_content}, + {'content': text_content}, db=db, ) hash = calculate_sha256_string(text_content) if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: - Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db) + Files.update_file_data_by_id(file.id, {'status': 'completed'}, db=db) Files.update_file_hash_by_id(file.id, hash, db=db) return { - "status": True, - "collection_name": None, - "filename": file.filename, - "content": text_content, + 'status': True, + 'collection_name': None, + 'filename': file.filename, + 'content': text_content, } else: try: @@ -1844,14 +1701,14 @@ def process_file( docs=docs, collection_name=collection_name, metadata={ - "file_id": file.id, - "name": file.filename, - "hash": hash, + 'file_id': file.id, + 'name': file.filename, + 'hash': hash, }, add=(True if form_data.collection_name else False), user=user, ) - log.info(f"added {len(docs)} items to collection {collection_name}") + log.info(f'added {len(docs)} items to collection {collection_name}') if result: # Fresh session for the final update. @@ -1859,26 +1716,26 @@ def process_file( Files.update_file_metadata_by_id( file.id, { - "collection_name": collection_name, + 'collection_name': collection_name, }, db=session, ) Files.update_file_data_by_id( file.id, - {"status": "completed"}, + {'status': 'completed'}, db=session, ) Files.update_file_hash_by_id(file.id, hash, db=session) return { - "status": True, - "collection_name": collection_name, - "filename": file.filename, - "content": text_content, + 'status': True, + 'collection_name': collection_name, + 'filename': file.filename, + 'content': text_content, } else: - raise Exception("Error saving document to vector database") + raise Exception('Error saving document to vector database') except Exception as e: raise e @@ -1888,13 +1745,13 @@ def process_file( with get_db() as session: Files.update_file_data_by_id( file.id, - {"status": "failed"}, + {'status': 'failed'}, db=session, ) # Clear the hash so the file can be re-uploaded after fixing the issue Files.update_file_hash_by_id(file.id, None, db=session) - if "No pandoc was found" in str(e): + if 'No pandoc was found' in str(e): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, @@ -1906,9 +1763,7 @@ def process_file( ) else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) class ProcessTextForm(BaseModel): @@ -1917,7 +1772,7 @@ class ProcessTextForm(BaseModel): collection_name: Optional[str] = None -@router.post("/process/text") +@router.post('/process/text') async def process_text( request: Request, form_data: ProcessTextForm, @@ -1930,20 +1785,18 @@ async def process_text( docs = [ Document( page_content=form_data.content, - metadata={"name": form_data.name, "created_by": user.id}, + metadata={'name': form_data.name, 'created_by': user.id}, ) ] text_content = form_data.content - log.debug(f"text_content: {text_content}") + log.debug(f'text_content: {text_content}') - result = await run_in_threadpool( - save_docs_to_vector_db, request, docs, collection_name, user=user - ) + result = await run_in_threadpool(save_docs_to_vector_db, request, docs, collection_name, user=user) if result: return { - "status": True, - "collection_name": collection_name, - "content": text_content, + 'status': True, + 'collection_name': collection_name, + 'content': text_content, } else: raise HTTPException( @@ -1952,22 +1805,18 @@ async def process_text( ) -@router.post("/process/youtube") -@router.post("/process/web") +@router.post('/process/youtube') +@router.post('/process/web') async def process_web( request: Request, form_data: ProcessUrlForm, - process: bool = Query(True, description="Whether to process and save the content"), - overwrite: bool = Query( - True, description="Whether to overwrite existing collection" - ), + process: bool = Query(True, description='Whether to process and save the content'), + overwrite: bool = Query(True, description='Whether to overwrite existing collection'), user=Depends(get_verified_user), ): try: - content, docs = await run_in_threadpool( - get_content_from_url, request, form_data.url - ) - log.debug(f"text_content: {content}") + content, docs = await run_in_threadpool(get_content_from_url, request, form_data.url) + log.debug(f'text_content: {content}') if process: collection_name = form_data.collection_name @@ -1988,23 +1837,23 @@ async def process_web( collection_name = None return { - "status": True, - "collection_name": collection_name, - "filename": form_data.url, - "file": { - "data": { - "content": content, + 'status': True, + 'collection_name': collection_name, + 'filename': form_data.url, + 'file': { + 'data': { + 'content': content, }, - "meta": { - "name": form_data.url, - "source": form_data.url, + 'meta': { + 'name': form_data.url, + 'source': form_data.url, }, }, } else: return { - "status": True, - "content": content, + 'status': True, + 'content': content, } except Exception as e: log.exception(e) @@ -2014,9 +1863,7 @@ async def process_web( ) -def search_web( - request: Request, engine: str, query: str, user=None -) -> list[SearchResult]: +def search_web(request: Request, engine: str, query: str, user=None) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL @@ -2040,15 +1887,15 @@ def search_web( """ # TODO: add playwright to search the web - if engine == "ollama_cloud": + if engine == 'ollama_cloud': return search_ollama_cloud( - "https://ollama.com", + 'https://ollama.com', request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) - elif engine == "perplexity_search": + elif engine == 'perplexity_search': if request.app.state.config.PERPLEXITY_API_KEY: return search_perplexity_search( request.app.state.config.PERPLEXITY_API_KEY, @@ -2059,10 +1906,10 @@ def search_web( user, ) else: - raise Exception("No PERPLEXITY_API_KEY found in environment variables") - elif engine == "searxng": + raise Exception('No PERPLEXITY_API_KEY found in environment variables') + elif engine == 'searxng': if request.app.state.config.SEARXNG_QUERY_URL: - searxng_kwargs = {"language": request.app.state.config.SEARXNG_LANGUAGE} + searxng_kwargs = {'language': request.app.state.config.SEARXNG_LANGUAGE} return search_searxng( request.app.state.config.SEARXNG_QUERY_URL, query, @@ -2071,8 +1918,8 @@ def search_web( **searxng_kwargs, ) else: - raise Exception("No SEARXNG_QUERY_URL found in environment variables") - elif engine == "yacy": + raise Exception('No SEARXNG_QUERY_URL found in environment variables') + elif engine == 'yacy': if request.app.state.config.YACY_QUERY_URL: return search_yacy( request.app.state.config.YACY_QUERY_URL, @@ -2083,12 +1930,9 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No YACY_QUERY_URL found in environment variables") - elif engine == "google_pse": - if ( - request.app.state.config.GOOGLE_PSE_API_KEY - and request.app.state.config.GOOGLE_PSE_ENGINE_ID - ): + raise Exception('No YACY_QUERY_URL found in environment variables') + elif engine == 'google_pse': + if request.app.state.config.GOOGLE_PSE_API_KEY and request.app.state.config.GOOGLE_PSE_ENGINE_ID: return search_google_pse( request.app.state.config.GOOGLE_PSE_API_KEY, request.app.state.config.GOOGLE_PSE_ENGINE_ID, @@ -2098,10 +1942,8 @@ def search_web( referer=request.app.state.config.WEBUI_URL, ) else: - raise Exception( - "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" - ) - elif engine == "brave": + raise Exception('No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables') + elif engine == 'brave': if request.app.state.config.BRAVE_SEARCH_API_KEY: return search_brave( request.app.state.config.BRAVE_SEARCH_API_KEY, @@ -2110,8 +1952,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") - elif engine == "kagi": + raise Exception('No BRAVE_SEARCH_API_KEY found in environment variables') + elif engine == 'kagi': if request.app.state.config.KAGI_SEARCH_API_KEY: return search_kagi( request.app.state.config.KAGI_SEARCH_API_KEY, @@ -2120,8 +1962,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") - elif engine == "mojeek": + raise Exception('No KAGI_SEARCH_API_KEY found in environment variables') + elif engine == 'mojeek': if request.app.state.config.MOJEEK_SEARCH_API_KEY: return search_mojeek( request.app.state.config.MOJEEK_SEARCH_API_KEY, @@ -2130,8 +1972,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") - elif engine == "bocha": + raise Exception('No MOJEEK_SEARCH_API_KEY found in environment variables') + elif engine == 'bocha': if request.app.state.config.BOCHA_SEARCH_API_KEY: return search_bocha( request.app.state.config.BOCHA_SEARCH_API_KEY, @@ -2140,8 +1982,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables") - elif engine == "serpstack": + raise Exception('No BOCHA_SEARCH_API_KEY found in environment variables') + elif engine == 'serpstack': if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( request.app.state.config.SERPSTACK_API_KEY, @@ -2151,8 +1993,8 @@ def search_web( https_enabled=request.app.state.config.SERPSTACK_HTTPS, ) else: - raise Exception("No SERPSTACK_API_KEY found in environment variables") - elif engine == "serper": + raise Exception('No SERPSTACK_API_KEY found in environment variables') + elif engine == 'serper': if request.app.state.config.SERPER_API_KEY: return search_serper( request.app.state.config.SERPER_API_KEY, @@ -2161,8 +2003,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No SERPER_API_KEY found in environment variables") - elif engine == "serply": + raise Exception('No SERPER_API_KEY found in environment variables') + elif engine == 'serply': if request.app.state.config.SERPLY_API_KEY: return search_serply( request.app.state.config.SERPLY_API_KEY, @@ -2171,8 +2013,8 @@ def search_web( filter_list=request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No SERPLY_API_KEY found in environment variables") - elif engine == "duckduckgo": + raise Exception('No SERPLY_API_KEY found in environment variables') + elif engine == 'duckduckgo': return search_duckduckgo( query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, @@ -2180,7 +2022,7 @@ def search_web( concurrent_requests=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, backend=request.app.state.config.DDGS_BACKEND, ) - elif engine == "tavily": + elif engine == 'tavily': if request.app.state.config.TAVILY_API_KEY: return search_tavily( request.app.state.config.TAVILY_API_KEY, @@ -2189,8 +2031,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No TAVILY_API_KEY found in environment variables") - elif engine == "exa": + raise Exception('No TAVILY_API_KEY found in environment variables') + elif engine == 'exa': if request.app.state.config.EXA_API_KEY: return search_exa( request.app.state.config.EXA_API_KEY, @@ -2199,8 +2041,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No EXA_API_KEY found in environment variables") - elif engine == "searchapi": + raise Exception('No EXA_API_KEY found in environment variables') + elif engine == 'searchapi': if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( request.app.state.config.SEARCHAPI_API_KEY, @@ -2210,8 +2052,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No SEARCHAPI_API_KEY found in environment variables") - elif engine == "serpapi": + raise Exception('No SEARCHAPI_API_KEY found in environment variables') + elif engine == 'serpapi': if request.app.state.config.SERPAPI_API_KEY: return search_serpapi( request.app.state.config.SERPAPI_API_KEY, @@ -2221,15 +2063,15 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No SERPAPI_API_KEY found in environment variables") - elif engine == "jina": + raise Exception('No SERPAPI_API_KEY found in environment variables') + elif engine == 'jina': return search_jina( request.app.state.config.JINA_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.JINA_API_BASE_URL, ) - elif engine == "bing": + elif engine == 'bing': return search_bing( request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, request.app.state.config.BING_SEARCH_V7_ENDPOINT, @@ -2238,7 +2080,7 @@ def search_web( request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) - elif engine == "azure": + elif engine == 'azure': if ( request.app.state.config.AZURE_AI_SEARCH_API_KEY and request.app.state.config.AZURE_AI_SEARCH_ENDPOINT @@ -2254,16 +2096,16 @@ def search_web( ) else: raise Exception( - "AZURE_AI_SEARCH_API_KEY, AZURE_AI_SEARCH_ENDPOINT, and AZURE_AI_SEARCH_INDEX_NAME are required for Azure AI Search" + 'AZURE_AI_SEARCH_API_KEY, AZURE_AI_SEARCH_ENDPOINT, and AZURE_AI_SEARCH_INDEX_NAME are required for Azure AI Search' ) - elif engine == "exa": + elif engine == 'exa': return search_exa( request.app.state.config.EXA_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) - elif engine == "perplexity": + elif engine == 'perplexity': return search_perplexity( request.app.state.config.PERPLEXITY_API_KEY, query, @@ -2272,11 +2114,8 @@ def search_web( model=request.app.state.config.PERPLEXITY_MODEL, search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, ) - elif engine == "sougou": - if ( - request.app.state.config.SOUGOU_API_SID - and request.app.state.config.SOUGOU_API_SK - ): + elif engine == 'sougou': + if request.app.state.config.SOUGOU_API_SID and request.app.state.config.SOUGOU_API_SK: return search_sougou( request.app.state.config.SOUGOU_API_SID, request.app.state.config.SOUGOU_API_SK, @@ -2285,10 +2124,8 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception( - "No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables" - ) - elif engine == "firecrawl": + raise Exception('No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables') + elif engine == 'firecrawl': return search_firecrawl( request.app.state.config.FIRECRAWL_API_BASE_URL, request.app.state.config.FIRECRAWL_API_KEY, @@ -2296,7 +2133,7 @@ def search_web( request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) - elif engine == "external": + elif engine == 'external': return search_external( request, request.app.state.config.EXTERNAL_WEB_SEARCH_URL, @@ -2306,7 +2143,7 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, user=user, ) - elif engine == "yandex": + elif engine == 'yandex': return search_yandex( request, request.app.state.config.YANDEX_WEB_SEARCH_URL, @@ -2317,7 +2154,7 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, user=user, ) - elif engine == "youcom": + elif engine == 'youcom': return search_youcom( request.app.state.config.YOUCOM_API_KEY, query, @@ -2325,21 +2162,19 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: - raise Exception("No search engine API key found in environment variables") + raise Exception('No search engine API key found in environment variables') -@router.post("/process/web/search") -async def process_web_search( - request: Request, form_data: SearchForm, user=Depends(get_verified_user) -): +@router.post('/process/web/search') +async def process_web_search(request: Request, form_data: SearchForm, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_WEB_SEARCH: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if user.role != "admin" and not has_permission( - user.id, "features.web_search", request.app.state.config.USER_PERMISSIONS + if user.role != 'admin' and not has_permission( + user.id, 'features.web_search', request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -2350,9 +2185,7 @@ async def process_web_search( result_items = [] try: - logging.debug( - f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}" - ) + logging.debug(f'trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}') # Use semaphore to limit concurrent requests based on WEB_SEARCH_CONCURRENT_REQUESTS # 0 or None = unlimited (previous behavior), positive number = limited concurrency @@ -2373,9 +2206,7 @@ async def process_web_search( user, ) - search_tasks = [ - search_query_with_semaphore(query) for query in form_data.queries - ] + search_tasks = [search_query_with_semaphore(query) for query in form_data.queries] else: # Unlimited parallel execution (previous behavior) search_tasks = [ @@ -2399,7 +2230,7 @@ async def process_web_search( urls.append(item.link) urls = list(dict.fromkeys(urls)) - log.debug(f"urls: {urls}") + log.debug(f'urls: {urls}') except Exception as e: log.exception(e) @@ -2412,27 +2243,25 @@ async def process_web_search( if len(urls) == 0: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=ERROR_MESSAGES.DEFAULT("No results found from web search"), + detail=ERROR_MESSAGES.DEFAULT('No results found from web search'), ) try: if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER: - search_results = [ - item for result in search_results for item in result if result - ] + search_results = [item for result in search_results for item in result if result] docs = [ Document( page_content=result.snippet, metadata={ - "source": result.link, - "title": result.title, - "snippet": result.snippet, - "link": result.link, + 'source': result.link, + 'title': result.title, + 'snippet': result.snippet, + 'link': result.link, }, ) for result in search_results - if hasattr(result, "snippet") and result.snippet is not None + if hasattr(result, 'snippet') and result.snippet is not None ] else: loader = get_web_loader( @@ -2444,7 +2273,7 @@ async def process_web_search( docs = await loader.aload() urls = [ - doc.metadata.get("source") for doc in docs if doc.metadata.get("source") + doc.metadata.get('source') for doc in docs if doc.metadata.get('source') ] # only keep the urls returned by the loader result_items = [ dict(item) for item in result_items if item.link in urls @@ -2452,26 +2281,22 @@ async def process_web_search( if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: return { - "status": True, - "collection_name": None, - "filenames": urls, - "items": result_items, - "docs": [ + 'status': True, + 'collection_name': None, + 'filenames': urls, + 'items': result_items, + 'docs': [ { - "content": doc.page_content, - "metadata": doc.metadata, + 'content': doc.page_content, + 'metadata': doc.metadata, } for doc in docs ], - "loaded_count": len(docs), + 'loaded_count': len(docs), } else: # Create a single collection for all documents - collection_name = ( - f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[ - :63 - ] - ) + collection_name = f'web-search-{calculate_sha256_string("-".join(form_data.queries))}'[:63] try: await run_in_threadpool( @@ -2483,14 +2308,14 @@ async def process_web_search( user=user, ) except Exception as e: - log.debug(f"error saving docs: {e}") + log.debug(f'error saving docs: {e}') return { - "status": True, - "collection_names": [collection_name], - "items": result_items, - "filenames": urls, - "loaded_count": len(docs), + 'status': True, + 'collection_names': [collection_name], + 'items': result_items, + 'filenames': urls, + 'loaded_count': len(docs), } except Exception as e: log.exception(e) @@ -2506,20 +2331,20 @@ def _validate_collection_access(collection_names: list[str], user) -> None: Enforces ownership on user-memory-* and file-* collections. Admins bypass this check. """ - if user.role == "admin": + if user.role == 'admin': return for name in collection_names: - if name.startswith("user-memory-") and name != f"user-memory-{user.id}": + if name.startswith('user-memory-') and name != f'user-memory-{user.id}': raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - elif name.startswith("file-"): - file_id = name[len("file-") :] + elif name.startswith('file-'): + file_id = name[len('file-') :] if not has_access_to_file( file_id=file_id, - access_type="read", + access_type='read', user=user, ): raise HTTPException( @@ -2537,7 +2362,7 @@ class QueryDocForm(BaseModel): hybrid: Optional[bool] = None -@router.post("/query/doc") +@router.post('/query/doc') async def query_doc_handler( request: Request, form_data: QueryDocForm, @@ -2546,9 +2371,7 @@ async def query_doc_handler( _validate_collection_access([form_data.collection_name], user) try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( - form_data.hybrid is None or form_data.hybrid - ): + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (form_data.hybrid is None or form_data.hybrid): collection_results = {} collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_name=form_data.collection_name @@ -2562,21 +2385,12 @@ async def query_doc_handler( ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=( - ( - lambda query, documents: request.app.state.RERANKING_FUNCTION( - query, documents, user=user - ) - ) + (lambda query, documents: request.app.state.RERANKING_FUNCTION(query, documents, user=user)) if request.app.state.RERANKING_FUNCTION else None ), - k_reranker=form_data.k_reranker - or request.app.state.config.TOP_K_RERANKER, - r=( - form_data.r - if form_data.r - else request.app.state.config.RELEVANCE_THRESHOLD - ), + k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, + r=(form_data.r if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD), hybrid_bm25_weight=( form_data.hybrid_bm25_weight if form_data.hybrid_bm25_weight @@ -2613,7 +2427,7 @@ class QueryCollectionsForm(BaseModel): enable_enriched_texts: Optional[bool] = None -@router.post("/query/collection") +@router.post('/query/collection') async def query_collection_handler( request: Request, form_data: QueryCollectionsForm, @@ -2622,9 +2436,7 @@ async def query_collection_handler( _validate_collection_access(form_data.collection_names, user) try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( - form_data.hybrid is None or form_data.hybrid - ): + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (form_data.hybrid is None or form_data.hybrid): return await query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], @@ -2633,21 +2445,12 @@ async def query_collection_handler( ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=( - ( - lambda query, documents: request.app.state.RERANKING_FUNCTION( - query, documents, user=user - ) - ) + (lambda query, documents: request.app.state.RERANKING_FUNCTION(query, documents, user=user)) if request.app.state.RERANKING_FUNCTION else None ), - k_reranker=form_data.k_reranker - or request.app.state.config.TOP_K_RERANKER, - r=( - form_data.r - if form_data.r - else request.app.state.config.RELEVANCE_THRESHOLD - ), + k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, + r=(form_data.r if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD), hybrid_bm25_weight=( form_data.hybrid_bm25_weight if form_data.hybrid_bm25_weight @@ -2689,7 +2492,7 @@ class DeleteForm(BaseModel): file_id: str -@router.post("/delete") +@router.post('/delete') def delete_entries_from_collection( form_data: DeleteForm, user=Depends(get_admin_user), @@ -2707,25 +2510,25 @@ def delete_entries_from_collection( VECTOR_DB_CLIENT.delete( collection_name=form_data.collection_name, - metadata={"hash": hash}, + metadata={'hash': hash}, ) - return {"status": True} + return {'status': True} else: - return {"status": False} + return {'status': False} except Exception as e: log.exception(e) - return {"status": False} + return {'status': False} -@router.post("/reset/db") +@router.post('/reset/db') def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)): VECTOR_DB_CLIENT.reset() Knowledges.delete_all_knowledge(db=db) -@router.post("/reset/uploads") +@router.post('/reset/uploads') def reset_upload_dir(user=Depends(get_admin_user)) -> bool: - folder = f"{UPLOAD_DIR}" + folder = f'{UPLOAD_DIR}' try: # Check if the directory exists if os.path.exists(folder): @@ -2738,23 +2541,19 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - log.exception(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f'Failed to delete {file_path}. Reason: {e}') else: - log.warning(f"The directory {folder} does not exist") + log.warning(f'The directory {folder} does not exist') except Exception as e: - log.exception(f"Failed to process the directory {folder}. Reason: {e}") + log.exception(f'Failed to process the directory {folder}. Reason: {e}') return True -if ENV == "dev": +if ENV == 'dev': - @router.get("/ef/{text}") - async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): - return { - "result": await request.app.state.EMBEDDING_FUNCTION( - text, prefix=RAG_EMBEDDING_QUERY_PREFIX - ) - } + @router.get('/ef/{text}') + async def get_embeddings(request: Request, text: Optional[str] = 'Hello World!'): + return {'result': await request.app.state.EMBEDDING_FUNCTION(text, prefix=RAG_EMBEDDING_QUERY_PREFIX)} class BatchProcessFilesForm(BaseModel): @@ -2773,7 +2572,7 @@ class BatchProcessFilesResponse(BaseModel): errors: List[BatchProcessFilesResult] -@router.post("/process/files/batch") +@router.post('/process/files/batch') async def process_files_batch( request: Request, form_data: BatchProcessFilesForm, @@ -2805,31 +2604,31 @@ async def process_files_batch( file_errors.append( BatchProcessFilesResult( file_id=file.id, - status="failed", - error="File not found", + status='failed', + error='File not found', ) ) continue - if db_file.user_id != user.id and user.role != "admin": + if db_file.user_id != user.id and user.role != 'admin': file_errors.append( BatchProcessFilesResult( file_id=file.id, - status="failed", - error="Permission denied: not file owner", + status='failed', + error='Permission denied: not file owner', ) ) continue - text_content = file.data.get("content", "") + text_content = file.data.get('content', '') docs: List[Document] = [ Document( - page_content=text_content.replace("
", "\n"), + page_content=text_content.replace('
', '\n'), metadata={ **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + 'name': file.filename, + 'created_by': file.user_id, + 'file_id': file.id, + 'source': file.filename, }, ) ] @@ -2839,18 +2638,14 @@ async def process_files_batch( file_updates.append( FileUpdateForm( hash=calculate_sha256_string(text_content), - data={"content": text_content}, + data={'content': text_content}, ) ) - file_results.append( - BatchProcessFilesResult(file_id=file.id, status="prepared") - ) + file_results.append(BatchProcessFilesResult(file_id=file.id, status='prepared')) except Exception as e: - log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}") - file_errors.append( - BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e)) - ) + log.error(f'process_files_batch: Error processing file {file.id}: {str(e)}') + file_errors.append(BatchProcessFilesResult(file_id=file.id, status='failed', error=str(e))) # Save all documents in one batch if all_docs: @@ -2867,18 +2662,12 @@ async def process_files_batch( # Update all files with collection name for file_update, file_result in zip(file_updates, file_results): Files.update_file_by_id(id=file_result.file_id, form_data=file_update) - file_result.status = "completed" + file_result.status = 'completed' except Exception as e: - log.error( - f"process_files_batch: Error saving documents to vector DB: {str(e)}" - ) + log.error(f'process_files_batch: Error saving documents to vector DB: {str(e)}') for file_result in file_results: - file_result.status = "failed" - file_errors.append( - BatchProcessFilesResult( - file_id=file_result.file_id, status="failed", error=str(e) - ) - ) + file_result.status = 'failed' + file_errors.append(BatchProcessFilesResult(file_id=file_result.file_id, status='failed', error=str(e))) return BatchProcessFilesResponse(results=file_results, errors=file_errors) diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 4d56a7e97f..ed721ce335 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -37,32 +37,32 @@ log = logging.getLogger(__name__) router = APIRouter() # SCIM 2.0 Schema URIs -SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" -SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" -SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse" -SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error" +SCIM_USER_SCHEMA = 'urn:ietf:params:scim:schemas:core:2.0:User' +SCIM_GROUP_SCHEMA = 'urn:ietf:params:scim:schemas:core:2.0:Group' +SCIM_LIST_RESPONSE_SCHEMA = 'urn:ietf:params:scim:api:messages:2.0:ListResponse' +SCIM_ERROR_SCHEMA = 'urn:ietf:params:scim:api:messages:2.0:Error' # SCIM Resource Types -SCIM_RESOURCE_TYPE_USER = "User" -SCIM_RESOURCE_TYPE_GROUP = "Group" +SCIM_RESOURCE_TYPE_USER = 'User' +SCIM_RESOURCE_TYPE_GROUP = 'Group' def scim_error(status_code: int, detail: str, scim_type: Optional[str] = None): """Create a SCIM-compliant error response""" error_body = { - "schemas": [SCIM_ERROR_SCHEMA], - "status": str(status_code), - "detail": detail, + 'schemas': [SCIM_ERROR_SCHEMA], + 'status': str(status_code), + 'detail': detail, } if scim_type: - error_body["scimType"] = scim_type + error_body['scimType'] = scim_type elif status_code == 404: - error_body["scimType"] = "invalidValue" + error_body['scimType'] = 'invalidValue' elif status_code == 409: - error_body["scimType"] = "uniqueness" + error_body['scimType'] = 'uniqueness' elif status_code == 400: - error_body["scimType"] = "invalidSyntax" + error_body['scimType'] = 'invalidSyntax' return JSONResponse(status_code=status_code, content=error_body) @@ -101,7 +101,7 @@ class SCIMEmail(BaseModel): """SCIM Email""" value: str - type: Optional[str] = "work" + type: Optional[str] = 'work' primary: bool = True display: Optional[str] = None @@ -110,7 +110,7 @@ class SCIMPhoto(BaseModel): """SCIM Photo""" value: str - type: Optional[str] = "photo" + type: Optional[str] = 'photo' primary: bool = True display: Optional[str] = None @@ -119,8 +119,8 @@ class SCIMGroupMember(BaseModel): """SCIM Group Member""" value: str # User ID - ref: Optional[str] = Field(None, alias="$ref") - type: Optional[str] = "User" + ref: Optional[str] = Field(None, alias='$ref') + type: Optional[str] = 'User' display: Optional[str] = None @@ -227,13 +227,11 @@ class SCIMPatchOperation(BaseModel): class SCIMPatchRequest(BaseModel): """SCIM Patch Request""" - schemas: List[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"] + schemas: List[str] = ['urn:ietf:params:scim:api:messages:2.0:PatchOp'] Operations: List[SCIMPatchOperation] -def get_scim_auth( - request: Request, authorization: Optional[str] = Header(None) -) -> bool: +def get_scim_auth(request: Request, authorization: Optional[str] = Header(None)) -> bool: """ Verify SCIM authentication Checks for SCIM-specific bearer token configured in the system @@ -241,8 +239,8 @@ def get_scim_auth( if not authorization: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authorization header required", - headers={"WWW-Authenticate": "Bearer"}, + detail='Authorization header required', + headers={'WWW-Authenticate': 'Bearer'}, ) try: @@ -250,42 +248,40 @@ def get_scim_auth( if len(parts) != 2: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authorization format. Expected: Bearer ", + detail='Invalid authorization format. Expected: Bearer ', ) scheme, token = parts - if scheme.lower() != "bearer": + if scheme.lower() != 'bearer': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication scheme", + detail='Invalid authentication scheme', ) # Check if SCIM is enabled - enable_scim = getattr(request.app.state, "ENABLE_SCIM", False) - log.info( - f"SCIM auth check - raw ENABLE_SCIM: {enable_scim}, type: {type(enable_scim)}" - ) + enable_scim = getattr(request.app.state, 'ENABLE_SCIM', False) + log.info(f'SCIM auth check - raw ENABLE_SCIM: {enable_scim}, type: {type(enable_scim)}') # Handle both PersistentConfig and direct value - if hasattr(enable_scim, "value"): + if hasattr(enable_scim, 'value'): enable_scim = enable_scim.value if not enable_scim: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="SCIM is not enabled", + detail='SCIM is not enabled', ) # Verify the SCIM token - scim_token = getattr(request.app.state, "SCIM_TOKEN", None) + scim_token = getattr(request.app.state, 'SCIM_TOKEN', None) # Handle both PersistentConfig and direct value - if hasattr(scim_token, "value"): + if hasattr(scim_token, 'value'): scim_token = scim_token.value - log.debug(f"SCIM token configured: {bool(scim_token)}") + log.debug(f'SCIM token configured: {bool(scim_token)}') if not scim_token or token != scim_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid SCIM token", + detail='Invalid SCIM token', ) return True @@ -293,13 +289,13 @@ def get_scim_auth( # Re-raise HTTP exceptions as-is raise except Exception as e: - log.error(f"SCIM authentication error: {e}") + log.error(f'SCIM authentication error: {e}') import traceback - log.error(f"Traceback: {traceback.format_exc()}") + log.error(f'Traceback: {traceback.format_exc()}') raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication failed", + detail='Authentication failed', ) @@ -311,8 +307,8 @@ def get_external_id(user: UserModel) -> Optional[str]: if not user.scim: return None for provider_data in user.scim.values(): - if isinstance(provider_data, dict) and "external_id" in provider_data: - return provider_data["external_id"] + if isinstance(provider_data, dict) and 'external_id' in provider_data: + return provider_data['external_id'] return None @@ -324,7 +320,7 @@ def get_scim_provider() -> str: if not SCIM_AUTH_PROVIDER: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="SCIM_AUTH_PROVIDER environment variable is required when SCIM is enabled", + detail='SCIM_AUTH_PROVIDER environment variable is required when SCIM is enabled', ) return SCIM_AUTH_PROVIDER @@ -343,18 +339,18 @@ def find_user_by_external_id(external_id: str, db=None) -> Optional[UserModel]: def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: """Convert internal User model to SCIM User""" # Parse display name into name components - name_parts = user.name.split(" ", 1) if user.name else ["", ""] - given_name = name_parts[0] if name_parts else "" - family_name = name_parts[1] if len(name_parts) > 1 else "" + name_parts = user.name.split(' ', 1) if user.name else ['', ''] + given_name = name_parts[0] if name_parts else '' + family_name = name_parts[1] if len(name_parts) > 1 else '' # Get user's groups user_groups = Groups.get_groups_by_member_id(user.id, db=db) groups = [ { - "value": group.id, - "display": group.name, - "$ref": f"{request.base_url}api/v1/scim/v2/Groups/{group.id}", - "type": "direct", + 'value': group.id, + 'display': group.name, + '$ref': f'{request.base_url}api/v1/scim/v2/Groups/{group.id}', + 'type': 'direct', } for group in user_groups ] @@ -370,22 +366,14 @@ def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: ), displayName=user.name, emails=[SCIMEmail(value=user.email)], - active=user.role != "pending", - photos=( - [SCIMPhoto(value=user.profile_image_url)] - if user.profile_image_url - else None - ), + active=user.role != 'pending', + photos=([SCIMPhoto(value=user.profile_image_url)] if user.profile_image_url else None), groups=groups if groups else None, meta=SCIMMeta( resourceType=SCIM_RESOURCE_TYPE_USER, - created=datetime.fromtimestamp( - user.created_at, tz=timezone.utc - ).isoformat(), - lastModified=datetime.fromtimestamp( - user.updated_at, tz=timezone.utc - ).isoformat(), - location=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", + created=datetime.fromtimestamp(user.created_at, tz=timezone.utc).isoformat(), + lastModified=datetime.fromtimestamp(user.updated_at, tz=timezone.utc).isoformat(), + location=f'{request.base_url}api/v1/scim/v2/Users/{user.id}', ), ) @@ -399,7 +387,7 @@ def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: members = [ SCIMGroupMember( value=user.id, - ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", + ref=f'{request.base_url}api/v1/scim/v2/Users/{user.id}', display=user.name, ) for user in users @@ -411,108 +399,104 @@ def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: members=members, meta=SCIMMeta( resourceType=SCIM_RESOURCE_TYPE_GROUP, - created=datetime.fromtimestamp( - group.created_at, tz=timezone.utc - ).isoformat(), - lastModified=datetime.fromtimestamp( - group.updated_at, tz=timezone.utc - ).isoformat(), - location=f"{request.base_url}api/v1/scim/v2/Groups/{group.id}", + created=datetime.fromtimestamp(group.created_at, tz=timezone.utc).isoformat(), + lastModified=datetime.fromtimestamp(group.updated_at, tz=timezone.utc).isoformat(), + location=f'{request.base_url}api/v1/scim/v2/Groups/{group.id}', ), ) # SCIM Service Provider Config -@router.get("/ServiceProviderConfig") +@router.get('/ServiceProviderConfig') async def get_service_provider_config(): """Get SCIM Service Provider Configuration""" return { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], - "patch": {"supported": True}, - "bulk": {"supported": False, "maxOperations": 1000, "maxPayloadSize": 1048576}, - "filter": {"supported": True, "maxResults": 200}, - "changePassword": {"supported": False}, - "sort": {"supported": False}, - "etag": {"supported": False}, - "authenticationSchemes": [ + 'schemas': ['urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig'], + 'patch': {'supported': True}, + 'bulk': {'supported': False, 'maxOperations': 1000, 'maxPayloadSize': 1048576}, + 'filter': {'supported': True, 'maxResults': 200}, + 'changePassword': {'supported': False}, + 'sort': {'supported': False}, + 'etag': {'supported': False}, + 'authenticationSchemes': [ { - "type": "oauthbearertoken", - "name": "OAuth Bearer Token", - "description": "Authentication using OAuth 2.0 Bearer Token", + 'type': 'oauthbearertoken', + 'name': 'OAuth Bearer Token', + 'description': 'Authentication using OAuth 2.0 Bearer Token', } ], } # SCIM Resource Types -@router.get("/ResourceTypes") +@router.get('/ResourceTypes') async def get_resource_types(request: Request): """Get SCIM Resource Types""" return [ { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], - "id": "User", - "name": "User", - "endpoint": "/Users", - "schema": SCIM_USER_SCHEMA, - "meta": { - "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/User", - "resourceType": "ResourceType", + 'schemas': ['urn:ietf:params:scim:schemas:core:2.0:ResourceType'], + 'id': 'User', + 'name': 'User', + 'endpoint': '/Users', + 'schema': SCIM_USER_SCHEMA, + 'meta': { + 'location': f'{request.base_url}api/v1/scim/v2/ResourceTypes/User', + 'resourceType': 'ResourceType', }, }, { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], - "id": "Group", - "name": "Group", - "endpoint": "/Groups", - "schema": SCIM_GROUP_SCHEMA, - "meta": { - "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/Group", - "resourceType": "ResourceType", + 'schemas': ['urn:ietf:params:scim:schemas:core:2.0:ResourceType'], + 'id': 'Group', + 'name': 'Group', + 'endpoint': '/Groups', + 'schema': SCIM_GROUP_SCHEMA, + 'meta': { + 'location': f'{request.base_url}api/v1/scim/v2/ResourceTypes/Group', + 'resourceType': 'ResourceType', }, }, ] # SCIM Schemas -@router.get("/Schemas") +@router.get('/Schemas') async def get_schemas(): """Get SCIM Schemas""" return [ { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], - "id": SCIM_USER_SCHEMA, - "name": "User", - "description": "User Account", - "attributes": [ + 'schemas': ['urn:ietf:params:scim:schemas:core:2.0:Schema'], + 'id': SCIM_USER_SCHEMA, + 'name': 'User', + 'description': 'User Account', + 'attributes': [ { - "name": "userName", - "type": "string", - "required": True, - "uniqueness": "server", + 'name': 'userName', + 'type': 'string', + 'required': True, + 'uniqueness': 'server', }, - {"name": "displayName", "type": "string", "required": True}, + {'name': 'displayName', 'type': 'string', 'required': True}, { - "name": "emails", - "type": "complex", - "multiValued": True, - "required": True, + 'name': 'emails', + 'type': 'complex', + 'multiValued': True, + 'required': True, }, - {"name": "active", "type": "boolean", "required": False}, + {'name': 'active', 'type': 'boolean', 'required': False}, ], }, { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], - "id": SCIM_GROUP_SCHEMA, - "name": "Group", - "description": "Group", - "attributes": [ - {"name": "displayName", "type": "string", "required": True}, + 'schemas': ['urn:ietf:params:scim:schemas:core:2.0:Schema'], + 'id': SCIM_GROUP_SCHEMA, + 'name': 'Group', + 'description': 'Group', + 'attributes': [ + {'name': 'displayName', 'type': 'string', 'required': True}, { - "name": "members", - "type": "complex", - "multiValued": True, - "required": False, + 'name': 'members', + 'type': 'complex', + 'multiValued': True, + 'required': False, }, ], }, @@ -520,7 +504,7 @@ async def get_schemas(): # Users endpoints -@router.get("/Users", response_model=SCIMListResponse) +@router.get('/Users', response_model=SCIMListResponse) async def get_users( request: Request, startIndex: int = Query(1), @@ -540,24 +524,24 @@ async def get_users( # Get users from database if filter: # Simple filter parsing - supports userName eq, externalId eq - if "userName eq" in filter: + if 'userName eq' in filter: email = filter.split('"')[1] user = Users.get_user_by_email(email, db=db) users_list = [user] if user else [] total = 1 if user else 0 - elif "externalId eq" in filter: + elif 'externalId eq' in filter: external_id = filter.split('"')[1] user = find_user_by_external_id(external_id, db=db) users_list = [user] if user else [] total = 1 if user else 0 else: response = Users.get_users(skip=skip, limit=limit, db=db) - users_list = response["users"] - total = response["total"] + users_list = response['users'] + total = response['total'] else: response = Users.get_users(skip=skip, limit=limit, db=db) - users_list = response["users"] - total = response["total"] + users_list = response['users'] + total = response['total'] # Convert to SCIM format scim_users = [user_to_scim(user, request, db=db) for user in users_list] @@ -570,7 +554,7 @@ async def get_users( ) -@router.get("/Users/{user_id}", response_model=SCIMUser) +@router.get('/Users/{user_id}', response_model=SCIMUser) async def get_user( user_id: str, request: Request, @@ -580,14 +564,12 @@ async def get_user( """Get SCIM User by ID""" user = Users.get_user_by_id(user_id, db=db) if not user: - return scim_error( - status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" - ) + return scim_error(status_code=status.HTTP_404_NOT_FOUND, detail=f'User {user_id} not found') return user_to_scim(user, request, db=db) -@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) +@router.post('/Users', response_model=SCIMUser, status_code=status.HTTP_201_CREATED) async def create_user( request: Request, user_data: SCIMUserCreateRequest, @@ -601,7 +583,7 @@ async def create_user( if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail=f"User with externalId {user_data.externalId} already exists", + detail=f'User with externalId {user_data.externalId} already exists', ) # Determine primary email (lowercased per RFC 5321) @@ -617,7 +599,7 @@ async def create_user( if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail=f"User with email {email} already exists", + detail=f'User with email {email} already exists', ) # Create user @@ -629,10 +611,10 @@ async def create_user( if user_data.name.formatted: name = user_data.name.formatted elif user_data.name.givenName or user_data.name.familyName: - name = f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip() + name = f'{user_data.name.givenName or ""} {user_data.name.familyName or ""}'.strip() # Get profile image if provided - profile_image = "/user.png" + profile_image = '/user.png' if user_data.photos and len(user_data.photos) > 0: profile_image = user_data.photos[0].value @@ -641,14 +623,14 @@ async def create_user( name=name, email=email, profile_image_url=profile_image, - role="user" if user_data.active else "pending", + role='user' if user_data.active else 'pending', db=db, ) if not new_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create user", + detail='Failed to create user', ) # Store externalId in the scim field @@ -660,7 +642,7 @@ async def create_user( return user_to_scim(new_user, request, db=db) -@router.put("/Users/{user_id}", response_model=SCIMUser) +@router.put('/Users/{user_id}', response_model=SCIMUser) async def update_user( user_id: str, request: Request, @@ -673,39 +655,37 @@ async def update_user( if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"User {user_id} not found", + detail=f'User {user_id} not found', ) # Build update dict update_data = {} if user_data.userName: - update_data["email"] = user_data.userName + update_data['email'] = user_data.userName if user_data.displayName: - update_data["name"] = user_data.displayName + update_data['name'] = user_data.displayName elif user_data.name: if user_data.name.formatted: - update_data["name"] = user_data.name.formatted + update_data['name'] = user_data.name.formatted elif user_data.name.givenName or user_data.name.familyName: - update_data["name"] = ( - f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip() - ) + update_data['name'] = f'{user_data.name.givenName or ""} {user_data.name.familyName or ""}'.strip() if user_data.emails and len(user_data.emails) > 0: - update_data["email"] = user_data.emails[0].value + update_data['email'] = user_data.emails[0].value if user_data.active is not None: - update_data["role"] = "user" if user_data.active else "pending" + update_data['role'] = 'user' if user_data.active else 'pending' if user_data.photos and len(user_data.photos) > 0: - update_data["profile_image_url"] = user_data.photos[0].value + update_data['profile_image_url'] = user_data.photos[0].value updated_user = Users.update_user_by_id(user_id, update_data, db=db) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update user", + detail='Failed to update user', ) # Update externalId in the scim field @@ -717,7 +697,7 @@ async def update_user( return user_to_scim(updated_user, request, db=db) -@router.patch("/Users/{user_id}", response_model=SCIMUser) +@router.patch('/Users/{user_id}', response_model=SCIMUser) async def patch_user( user_id: str, request: Request, @@ -730,7 +710,7 @@ async def patch_user( if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"User {user_id} not found", + detail=f'User {user_id} not found', ) update_data = {} @@ -740,18 +720,18 @@ async def patch_user( path = operation.path value = operation.value - if op == "replace": - if path == "active": - update_data["role"] = "user" if value else "pending" - elif path == "userName": - update_data["email"] = value - elif path == "displayName": - update_data["name"] = value - elif path == "emails[primary eq true].value": - update_data["email"] = value - elif path == "name.formatted": - update_data["name"] = value - elif path == "externalId": + if op == 'replace': + if path == 'active': + update_data['role'] = 'user' if value else 'pending' + elif path == 'userName': + update_data['email'] = value + elif path == 'displayName': + update_data['name'] = value + elif path == 'emails[primary eq true].value': + update_data['email'] = value + elif path == 'name.formatted': + update_data['name'] = value + elif path == 'externalId': provider = get_scim_provider() Users.update_user_scim_by_id(user_id, provider, value, db=db) @@ -761,7 +741,7 @@ async def patch_user( if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update user", + detail='Failed to update user', ) else: updated_user = user @@ -769,7 +749,7 @@ async def patch_user( return user_to_scim(updated_user, request, db=db) -@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +@router.delete('/Users/{user_id}', status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: str, request: Request, @@ -781,21 +761,21 @@ async def delete_user( if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"User {user_id} not found", + detail=f'User {user_id} not found', ) success = Users.delete_user_by_id(user_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete user", + detail='Failed to delete user', ) return None # Groups endpoints -@router.get("/Groups", response_model=SCIMListResponse) +@router.get('/Groups', response_model=SCIMListResponse) async def get_groups( request: Request, startIndex: int = Query(1), @@ -830,7 +810,7 @@ async def get_groups( ) -@router.get("/Groups/{group_id}", response_model=SCIMGroup) +@router.get('/Groups/{group_id}', response_model=SCIMGroup) async def get_group( group_id: str, request: Request, @@ -842,13 +822,13 @@ async def get_group( if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Group {group_id} not found", + detail=f'Group {group_id} not found', ) return group_to_scim(group, request, db=db) -@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) +@router.post('/Groups', response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) async def create_group( request: Request, group_data: SCIMGroupCreateRequest, @@ -867,7 +847,7 @@ async def create_group( form = GroupForm( name=group_data.displayName, - description="", + description='', ) # Need to get the creating user's ID - we'll use the first admin @@ -875,14 +855,14 @@ async def create_group( if not admin_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No admin user found", + detail='No admin user found', ) new_group = Groups.insert_new_group(admin_user.id, form, db=db) if not new_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create group", + detail='Failed to create group', ) # Add members if provided @@ -902,7 +882,7 @@ async def create_group( return group_to_scim(new_group, request, db=db) -@router.put("/Groups/{group_id}", response_model=SCIMGroup) +@router.put('/Groups/{group_id}', response_model=SCIMGroup) async def update_group( group_id: str, request: Request, @@ -915,7 +895,7 @@ async def update_group( if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Group {group_id} not found", + detail=f'Group {group_id} not found', ) # Build update form @@ -936,13 +916,13 @@ async def update_group( if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update group", + detail='Failed to update group', ) return group_to_scim(updated_group, request, db=db) -@router.patch("/Groups/{group_id}", response_model=SCIMGroup) +@router.patch('/Groups/{group_id}', response_model=SCIMGroup) async def patch_group( group_id: str, request: Request, @@ -955,7 +935,7 @@ async def patch_group( if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Group {group_id} not found", + detail=f'Group {group_id} not found', ) from open_webui.models.groups import GroupUpdateForm @@ -970,26 +950,22 @@ async def patch_group( path = operation.path value = operation.value - if op == "replace": - if path == "displayName": + if op == 'replace': + if path == 'displayName': update_form.name = value - elif path == "members": + elif path == 'members': # Replace all members - Groups.set_group_user_ids_by_id( - group_id, [member["value"] for member in value], db=db - ) + Groups.set_group_user_ids_by_id(group_id, [member['value'] for member in value], db=db) - elif op == "add": - if path == "members": + elif op == 'add': + if path == 'members': # Add members if isinstance(value, list): for member in value: - if isinstance(member, dict) and "value" in member: - Groups.add_users_to_group( - group_id, [member["value"]], db=db - ) - elif op == "remove": - if path and path.startswith("members[value eq"): + if isinstance(member, dict) and 'value' in member: + Groups.add_users_to_group(group_id, [member['value']], db=db) + elif op == 'remove': + if path and path.startswith('members[value eq'): # Remove specific member member_id = path.split('"')[1] Groups.remove_users_from_group(group_id, [member_id], db=db) @@ -999,13 +975,13 @@ async def patch_group( if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update group", + detail='Failed to update group', ) return group_to_scim(updated_group, request, db=db) -@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) +@router.delete('/Groups/{group_id}', status_code=status.HTTP_204_NO_CONTENT) async def delete_group( group_id: str, request: Request, @@ -1017,14 +993,14 @@ async def delete_group( if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Group {group_id} not found", + detail=f'Group {group_id} not found', ) success = Groups.delete_group_by_id(group_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete group", + detail='Failed to delete group', ) return None diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py index e18b561321..1838914e4a 100644 --- a/backend/open_webui/routers/skills.py +++ b/backend/open_webui/routers/skills.py @@ -36,18 +36,16 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[SkillUserResponse]) +@router.get('/', response_model=list[SkillUserResponse]) async def get_skills( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: skills = Skills.get_skills(db=db) else: - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} all_skills = Skills.get_skills(db=db) skills = [ skill @@ -55,9 +53,9 @@ async def get_skills( if skill.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) @@ -71,7 +69,7 @@ async def get_skills( ############################ -@router.get("/list", response_model=SkillAccessListResponse) +@router.get('/list', response_model=SkillAccessListResponse) async def get_skill_list( query: Optional[str] = None, view_option: Optional[str] = None, @@ -86,16 +84,16 @@ async def get_skill_list( filter = {} if query: - filter["query"] = query + filter['query'] = query if view_option: - filter["view_option"] = view_option + filter['view_option'] = view_option - if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + if not (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL): groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: - filter["group_ids"] = [group.id for group in groups] + filter['group_ids'] = [group.id for group in groups] - filter["user_id"] = user.id + filter['user_id'] = user.id result = Skills.search_skills(user.id, filter=filter, skip=skip, limit=limit, db=db) @@ -104,13 +102,13 @@ async def get_skill_list( SkillAccessResponse( **skill.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == skill.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) ), @@ -126,15 +124,15 @@ async def get_skill_list( ############################ -@router.get("/export", response_model=list[SkillModel]) +@router.get('/export', response_model=list[SkillModel]) async def export_skills( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( + if user.role != 'admin' and not has_permission( user.id, - "workspace.skills", + 'workspace.skills', request.app.state.config.USER_PERMISSIONS, db=db, ): @@ -143,10 +141,10 @@ async def export_skills( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: return Skills.get_skills(db=db) else: - return Skills.get_skills_by_user_id(user.id, "read", db=db) + return Skills.get_skills_by_user_id(user.id, 'read', db=db) ############################ @@ -154,22 +152,22 @@ async def export_skills( ############################ -@router.post("/create", response_model=Optional[SkillResponse]) +@router.post('/create', response_model=Optional[SkillResponse]) async def create_new_skill( request: Request, form_data: SkillForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( - user.id, "workspace.skills", request.app.state.config.USER_PERMISSIONS, db=db + if user.role != 'admin' and not has_permission( + user.id, 'workspace.skills', request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - form_data.id = form_data.id.lower().replace(" ", "-") + form_data.id = form_data.id.lower().replace(' ', '-') existing = Skills.get_skill_by_id(form_data.id, db=db) if existing is not None: @@ -185,10 +183,10 @@ async def create_new_skill( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating skill"), + detail=ERROR_MESSAGES.DEFAULT('Error creating skill'), ) except Exception as e: - log.exception(f"Failed to create skill: {e}") + log.exception(f'Failed to create skill: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -200,34 +198,32 @@ async def create_new_skill( ############################ -@router.get("/id/{id}", response_model=Optional[SkillAccessResponse]) -async def get_skill_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}', response_model=Optional[SkillAccessResponse]) +async def get_skill_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): skill = Skills.get_skill_by_id(id, db=db) if skill: if ( - user.role == "admin" + user.role == 'admin' or skill.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="read", + permission='read', db=db, ) ): return SkillAccessResponse( **skill.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == skill.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) ), @@ -249,7 +245,7 @@ async def get_skill_by_id( ############################ -@router.post("/id/{id}/update", response_model=Optional[SkillModel]) +@router.post('/id/{id}/update', response_model=Optional[SkillModel]) async def update_skill_by_id( request: Request, id: str, @@ -268,12 +264,12 @@ async def update_skill_by_id( skill.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -282,7 +278,7 @@ async def update_skill_by_id( try: updated = { - **form_data.model_dump(exclude={"id"}), + **form_data.model_dump(exclude={'id'}), } skill = Skills.update_skill_by_id(id, updated, db=db) @@ -292,7 +288,7 @@ async def update_skill_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating skill"), + detail=ERROR_MESSAGES.DEFAULT('Error updating skill'), ) except Exception as e: raise HTTPException( @@ -310,7 +306,7 @@ class SkillAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/id/{id}/access/update", response_model=Optional[SkillModel]) +@router.post('/id/{id}/access/update', response_model=Optional[SkillModel]) async def update_skill_access_by_id( request: Request, id: str, @@ -329,12 +325,12 @@ async def update_skill_access_by_id( skill.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -346,10 +342,10 @@ async def update_skill_access_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_skills", + 'sharing.public_skills', ) - AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db) + AccessGrants.set_access_grants('skill', id, form_data.access_grants, db=db) return Skills.get_skill_by_id(id, db=db) @@ -359,20 +355,18 @@ async def update_skill_access_by_id( ############################ -@router.post("/id/{id}/toggle", response_model=Optional[SkillModel]) -async def toggle_skill_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.post('/id/{id}/toggle', response_model=Optional[SkillModel]) +async def toggle_skill_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): skill = Skills.get_skill_by_id(id, db=db) if skill: if ( - user.role == "admin" + user.role == 'admin' or skill.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) ): @@ -383,7 +377,7 @@ async def toggle_skill_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error toggling skill"), + detail=ERROR_MESSAGES.DEFAULT('Error toggling skill'), ) else: raise HTTPException( @@ -402,7 +396,7 @@ async def toggle_skill_by_id( ############################ -@router.delete("/id/{id}/delete", response_model=bool) +@router.delete('/id/{id}/delete', response_model=bool) async def delete_skill_by_id( request: Request, id: str, @@ -420,12 +414,12 @@ async def delete_skill_by_id( skill.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index f6404da05c..0bb5813e6f 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -52,36 +52,34 @@ class ActiveChatsForm(BaseModel): chat_ids: list[str] -@router.post("/active/chats") -async def check_active_chats( - request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user) -): +@router.post('/active/chats') +async def check_active_chats(request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user)): """Check which chat IDs have active tasks.""" from open_webui.tasks import get_active_chat_ids active = await get_active_chat_ids(request.app.state.redis, form_data.chat_ids) - return {"active_chat_ids": active} + return {'active_chat_ids': active} -@router.get("/config") +@router.get('/config') async def get_task_config(request: Request, user=Depends(get_verified_user)): return { - "TASK_MODEL": request.app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, - "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, - "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, - "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, - "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - "VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE, + 'TASK_MODEL': request.app.state.config.TASK_MODEL, + 'TASK_MODEL_EXTERNAL': request.app.state.config.TASK_MODEL_EXTERNAL, + 'TITLE_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + 'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE': request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, + 'ENABLE_AUTOCOMPLETE_GENERATION': request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + 'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH': request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + 'TAGS_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + 'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE': request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + 'ENABLE_FOLLOW_UP_GENERATION': request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, + 'ENABLE_TAGS_GENERATION': request.app.state.config.ENABLE_TAGS_GENERATION, + 'ENABLE_TITLE_GENERATION': request.app.state.config.ENABLE_TITLE_GENERATION, + 'ENABLE_SEARCH_QUERY_GENERATION': request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + 'ENABLE_RETRIEVAL_QUERY_GENERATION': request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + 'QUERY_GENERATION_PROMPT_TEMPLATE': request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + 'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE': request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + 'VOICE_MODE_PROMPT_TEMPLATE': request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE, } @@ -104,100 +102,73 @@ class TaskConfigForm(BaseModel): VOICE_MODE_PROMPT_TEMPLATE: Optional[str] -@router.post("/config/update") -async def update_task_config( - request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user) -): +@router.post('/config/update') +async def update_task_config(request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)): request.app.state.config.TASK_MODEL = form_data.TASK_MODEL request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION - request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( - form_data.TITLE_GENERATION_PROMPT_TEMPLATE - ) + request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = form_data.TITLE_GENERATION_PROMPT_TEMPLATE - request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = ( - form_data.ENABLE_FOLLOW_UP_GENERATION - ) - request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( - form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE - ) + request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = form_data.ENABLE_FOLLOW_UP_GENERATION + request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE - request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( - form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE - ) + request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE - request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( - form_data.ENABLE_AUTOCOMPLETE_GENERATION - ) + request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = form_data.ENABLE_AUTOCOMPLETE_GENERATION request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH ) - request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( - form_data.TAGS_GENERATION_PROMPT_TEMPLATE - ) + request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = form_data.TAGS_GENERATION_PROMPT_TEMPLATE request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION - request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( - form_data.ENABLE_SEARCH_QUERY_GENERATION - ) - request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( - form_data.ENABLE_RETRIEVAL_QUERY_GENERATION - ) + request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = form_data.ENABLE_SEARCH_QUERY_GENERATION + request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = form_data.ENABLE_RETRIEVAL_QUERY_GENERATION - request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.QUERY_GENERATION_PROMPT_TEMPLATE - ) - request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - ) + request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = form_data.QUERY_GENERATION_PROMPT_TEMPLATE + request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = ( - form_data.VOICE_MODE_PROMPT_TEMPLATE - ) + request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = form_data.VOICE_MODE_PROMPT_TEMPLATE return { - "TASK_MODEL": request.app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, - "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, - "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, - "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, - "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - "VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE, + 'TASK_MODEL': request.app.state.config.TASK_MODEL, + 'TASK_MODEL_EXTERNAL': request.app.state.config.TASK_MODEL_EXTERNAL, + 'ENABLE_TITLE_GENERATION': request.app.state.config.ENABLE_TITLE_GENERATION, + 'TITLE_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + 'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE': request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, + 'ENABLE_AUTOCOMPLETE_GENERATION': request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + 'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH': request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + 'TAGS_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + 'ENABLE_TAGS_GENERATION': request.app.state.config.ENABLE_TAGS_GENERATION, + 'ENABLE_FOLLOW_UP_GENERATION': request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, + 'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE': request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + 'ENABLE_SEARCH_QUERY_GENERATION': request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + 'ENABLE_RETRIEVAL_QUERY_GENERATION': request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + 'QUERY_GENERATION_PROMPT_TEMPLATE': request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + 'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE': request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + 'VOICE_MODE_PROMPT_TEMPLATE': request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE, } -@router.post("/title/completions") -async def generate_title( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - +@router.post('/title/completions') +async def generate_title(request: Request, form_data: dict, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_TITLE_GENERATION: return JSONResponse( status_code=status.HTTP_200_OK, - content={"detail": "Title generation is disabled"}, + content={'detail': 'Title generation is disabled'}, ) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -209,37 +180,33 @@ async def generate_title( models, ) - log.debug( - f"generating chat title using model {task_model_id} for user {user.email} " - ) + log.debug(f'generating chat title using model {task_model_id} for user {user.email} ') - if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": + if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != '': template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE - content = title_generation_template(template, form_data["messages"], user) + content = title_generation_template(template, form_data['messages'], user) - max_tokens = ( - models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000) - ) + max_tokens = models[task_model_id].get('info', {}).get('params', {}).get('max_tokens', 1000) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, **( - {"max_tokens": max_tokens} - if models[task_model_id].get("owned_by") == "ollama" + {'max_tokens': max_tokens} + if models[task_model_id].get('owned_by') == 'ollama' else { - "max_completion_tokens": max_tokens, + 'max_completion_tokens': max_tokens, } ), - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.TITLE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.TITLE_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -252,36 +219,33 @@ async def generate_title( try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - log.error("Exception occurred", exc_info=True) + log.error('Exception occurred', exc_info=True) return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "An internal error has occurred."}, + content={'detail': 'An internal error has occurred.'}, ) -@router.post("/follow_up/completions") -async def generate_follow_ups( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - +@router.post('/follow_up/completions') +async def generate_follow_ups(request: Request, form_data: dict, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION: return JSONResponse( status_code=status.HTTP_200_OK, - content={"detail": "Follow-up generation is disabled"}, + content={'detail': 'Follow-up generation is disabled'}, ) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -293,26 +257,24 @@ async def generate_follow_ups( models, ) - log.debug( - f"generating chat title using model {task_model_id} for user {user.email} " - ) + log.debug(f'generating chat title using model {task_model_id} for user {user.email} ') - if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "": + if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != '': template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE - content = follow_up_generation_template(template, form_data["messages"], user) + content = follow_up_generation_template(template, form_data['messages'], user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.FOLLOW_UP_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.FOLLOW_UP_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -325,36 +287,33 @@ async def generate_follow_ups( try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - log.error("Exception occurred", exc_info=True) + log.error('Exception occurred', exc_info=True) return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "An internal error has occurred."}, + content={'detail': 'An internal error has occurred.'}, ) -@router.post("/tags/completions") -async def generate_chat_tags( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - +@router.post('/tags/completions') +async def generate_chat_tags(request: Request, form_data: dict, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_TAGS_GENERATION: return JSONResponse( status_code=status.HTTP_200_OK, - content={"detail": "Tags generation is disabled"}, + content={'detail': 'Tags generation is disabled'}, ) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -366,26 +325,24 @@ async def generate_chat_tags( models, ) - log.debug( - f"generating chat tags using model {task_model_id} for user {user.email} " - ) + log.debug(f'generating chat tags using model {task_model_id} for user {user.email} ') - if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != '': template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE - content = tags_generation_template(template, form_data["messages"], user) + content = tags_generation_template(template, form_data['messages'], user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.TAGS_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.TAGS_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -398,29 +355,27 @@ async def generate_chat_tags( try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - log.error(f"Error generating chat completion: {e}") + log.error(f'Error generating chat completion: {e}') return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"detail": "An internal error has occurred."}, + content={'detail': 'An internal error has occurred.'}, ) -@router.post("/image_prompt/completions") -async def generate_image_prompt( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): +@router.post('/image_prompt/completions') +async def generate_image_prompt(request: Request, form_data: dict, user=Depends(get_verified_user)): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -432,26 +387,24 @@ async def generate_image_prompt( models, ) - log.debug( - f"generating image prompt using model {task_model_id} for user {user.email} " - ) + log.debug(f'generating image prompt using model {task_model_id} for user {user.email} ') - if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "": + if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != '': template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE - content = image_prompt_generation_template(template, form_data["messages"], user) + content = image_prompt_generation_template(template, form_data['messages'], user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.IMAGE_PROMPT_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.IMAGE_PROMPT_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -464,48 +417,45 @@ async def generate_image_prompt( try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - log.error("Exception occurred", exc_info=True) + log.error('Exception occurred', exc_info=True) return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "An internal error has occurred."}, + content={'detail': 'An internal error has occurred.'}, ) -@router.post("/queries/completions") -async def generate_queries( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - - type = form_data.get("type") - if type == "web_search": +@router.post('/queries/completions') +async def generate_queries(request: Request, form_data: dict, user=Depends(get_verified_user)): + type = form_data.get('type') + if type == 'web_search': if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", + detail=f'Search query generation is disabled', ) - elif type == "retrieval": + elif type == 'retrieval': if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Query generation is disabled", + detail=f'Query generation is disabled', ) - if getattr(request.state, "cached_queries", None): - log.info(f"Reusing cached queries: {request.state.cached_queries}") + if getattr(request.state, 'cached_queries', None): + log.info(f'Reusing cached queries: {request.state.cached_queries}') return request.state.cached_queries - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -517,26 +467,24 @@ async def generate_queries( models, ) - log.debug( - f"generating {type} queries using model {task_model_id} for user {user.email}" - ) + log.debug(f'generating {type} queries using model {task_model_id} for user {user.email}') - if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": + if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != '': template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE - content = query_generation_template(template, form_data["messages"], user) + content = query_generation_template(template, form_data['messages'], user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.QUERY_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.QUERY_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -551,46 +499,41 @@ async def generate_queries( except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, + content={'detail': str(e)}, ) -@router.post("/auto/completions") -async def generate_autocompletion( - request: Request, form_data: dict, user=Depends(get_verified_user) -): +@router.post('/auto/completions') +async def generate_autocompletion(request: Request, form_data: dict, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Autocompletion generation is disabled", + detail=f'Autocompletion generation is disabled', ) - type = form_data.get("type") - prompt = form_data.get("prompt") - messages = form_data.get("messages") + type = form_data.get('type') + prompt = form_data.get('prompt') + messages = form_data.get('messages') if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: - if ( - len(prompt) - > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH - ): + if len(prompt) > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + detail=f'Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}', ) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -602,11 +545,9 @@ async def generate_autocompletion( models, ) - log.debug( - f"generating autocompletion using model {task_model_id} for user {user.email}" - ) + log.debug(f'generating autocompletion using model {task_model_id} for user {user.email}') - if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": + if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != '': template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE @@ -614,14 +555,14 @@ async def generate_autocompletion( content = autocomplete_generation_template(template, prompt, messages, type, user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.AUTOCOMPLETE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.AUTOCOMPLETE_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -634,30 +575,27 @@ async def generate_autocompletion( try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - log.error(f"Error generating chat completion: {e}") + log.error(f'Error generating chat completion: {e}') return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"detail": "An internal error has occurred."}, + content={'detail': 'An internal error has occurred.'}, ) -@router.post("/emoji/completions") -async def generate_emoji( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): +@router.post('/emoji/completions') +async def generate_emoji(request: Request, form_data: dict, user=Depends(get_verified_user)): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) # Check if the user has a custom task model @@ -669,28 +607,28 @@ async def generate_emoji( models, ) - log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") + log.debug(f'generating emoji using model {task_model_id} for user {user.email} ') template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE - content = emoji_generation_template(template, form_data["prompt"], user) + content = emoji_generation_template(template, form_data['prompt'], user) payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, + 'model': task_model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': False, **( - {"max_tokens": 4} - if models[task_model_id].get("owned_by") == "ollama" + {'max_tokens': 4} + if models[task_model_id].get('owned_by') == 'ollama' else { - "max_completion_tokens": 4, + 'max_completion_tokens': 4, } ), - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "task": str(TASKS.EMOJI_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'task': str(TASKS.EMOJI_GENERATION), + 'task_body': form_data, + 'chat_id': form_data.get('chat_id', None), }, } @@ -705,47 +643,44 @@ async def generate_emoji( except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, + content={'detail': str(e)}, ) -@router.post("/moa/completions") -async def generate_moa_response( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): +@router.post('/moa/completions') +async def generate_moa_response(request: Request, form_data: dict, user=Depends(get_verified_user)): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + detail='Model not found', ) template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE content = moa_response_generation_template( template, - form_data["prompt"], - form_data["responses"], + form_data['prompt'], + form_data['responses'], ) payload = { - "model": model_id, - "messages": [{"role": "user", "content": content}], - "stream": form_data.get("stream", False), - "metadata": { - **(request.state.metadata if hasattr(request.state, "metadata") else {}), - "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.MOA_RESPONSE_GENERATION), - "task_body": form_data, + 'model': model_id, + 'messages': [{'role': 'user', 'content': content}], + 'stream': form_data.get('stream', False), + 'metadata': { + **(request.state.metadata if hasattr(request.state, 'metadata') else {}), + 'chat_id': form_data.get('chat_id', None), + 'task': str(TASKS.MOA_RESPONSE_GENERATION), + 'task_body': form_data, }, } @@ -760,5 +695,5 @@ async def generate_moa_response( except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, + content={'detail': str(e)}, ) diff --git a/backend/open_webui/routers/terminals.py b/backend/open_webui/routers/terminals.py index 322649e2d2..49c39c8bf7 100644 --- a/backend/open_webui/routers/terminals.py +++ b/backend/open_webui/routers/terminals.py @@ -23,10 +23,8 @@ log = logging.getLogger(__name__) router = APIRouter() -STREAMING_CONTENT_TYPES = ("application/octet-stream", "image/", "application/pdf") -STRIPPED_RESPONSE_HEADERS = frozenset( - ("transfer-encoding", "connection", "content-encoding", "content-length") -) +STREAMING_CONTENT_TYPES = ('application/octet-stream', 'image/', 'application/pdf') +STRIPPED_RESPONSE_HEADERS = frozenset(('transfer-encoding', 'connection', 'content-encoding', 'content-length')) def _sanitize_proxy_path(path: str) -> str | None: @@ -37,14 +35,14 @@ def _sanitize_proxy_path(path: str) -> str | None: decoded = unquote(path) normalized = posixpath.normpath(decoded) # Remove any leading slashes that would reset the base - cleaned = normalized.lstrip("/") + cleaned = normalized.lstrip('/') # Reject if normpath resolved to parent traversal or current-dir only - if cleaned.startswith("..") or cleaned == ".": + if cleaned.startswith('..') or cleaned == '.': return None return cleaned -@router.get("/") +@router.get('/') async def list_terminal_servers(request: Request, user=Depends(get_verified_user)): """Return terminal servers the authenticated user has access to.""" connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or [] @@ -52,20 +50,19 @@ async def list_terminal_servers(request: Request, user=Depends(get_verified_user return [ { - "id": connection.get("id", ""), - "url": connection.get("url", ""), - "name": connection.get("name", ""), + 'id': connection.get('id', ''), + 'url': connection.get('url', ''), + 'name': connection.get('name', ''), } for connection in connections - if connection.get("enabled", True) - and has_connection_access(user, connection, user_group_ids) + if connection.get('enabled', True) and has_connection_access(user, connection, user_group_ids) ] -PROXY_METHODS = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] +PROXY_METHODS = ['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS'] -@router.api_route("/{server_id}/{path:path}", methods=PROXY_METHODS) +@router.api_route('/{server_id}/{path:path}', methods=PROXY_METHODS) async def proxy_terminal( server_id: str, path: str, @@ -74,56 +71,52 @@ async def proxy_terminal( ): """Proxy a request to the admin terminal server identified by *server_id*.""" connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or [] - connection = next((c for c in connections if c.get("id") == server_id), None) + connection = next((c for c in connections if c.get('id') == server_id), None) if connection is None: - return JSONResponse( - {"error": f"Terminal server '{server_id}' not found"}, status_code=404 - ) + return JSONResponse({'error': f"Terminal server '{server_id}' not found"}, status_code=404) user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not has_connection_access(user, connection, user_group_ids): - return JSONResponse({"error": "Access denied"}, status_code=403) + return JSONResponse({'error': 'Access denied'}, status_code=403) - base_url = (connection.get("url") or "").rstrip("/") + base_url = (connection.get('url') or '').rstrip('/') if not base_url: - return JSONResponse( - {"error": "Terminal server URL not configured"}, status_code=503 - ) + return JSONResponse({'error': 'Terminal server URL not configured'}, status_code=503) safe_path = _sanitize_proxy_path(path) if safe_path is None: - return JSONResponse({"error": "Invalid path"}, status_code=400) + return JSONResponse({'error': 'Invalid path'}, status_code=400) - target_url = f"{base_url}/{safe_path}" + target_url = f'{base_url}/{safe_path}' # Route through orchestrator policy endpoint if policy_id is set - policy_id = connection.get("policy_id") + policy_id = connection.get('policy_id') if policy_id: - target_url = f"{base_url}/p/{policy_id}/{safe_path}" + target_url = f'{base_url}/p/{policy_id}/{safe_path}' if request.query_params: - target_url += f"?{request.query_params}" + target_url += f'?{request.query_params}' - headers = {"X-User-Id": user.id} + headers = {'X-User-Id': user.id} cookies = {} - auth_type = connection.get("auth_type", "bearer") + auth_type = connection.get('auth_type', 'bearer') - if auth_type == "bearer": - headers["Authorization"] = f"Bearer {connection.get('key', '')}" - elif auth_type == "session": + if auth_type == 'bearer': + headers['Authorization'] = f'Bearer {connection.get("key", "")}' + elif auth_type == 'session': cookies = request.cookies - headers["Authorization"] = f"Bearer {request.state.token.credentials}" - elif auth_type == "system_oauth": + headers['Authorization'] = f'Bearer {request.state.token.credentials}' + elif auth_type == 'system_oauth': cookies = request.cookies - oauth_token = request.headers.get("x-oauth-access-token", "") + oauth_token = request.headers.get('x-oauth-access-token', '') if oauth_token: - headers["Authorization"] = f"Bearer {oauth_token}" + headers['Authorization'] = f'Bearer {oauth_token}' # auth_type == "none": no Authorization header - content_type = request.headers.get("content-type") + content_type = request.headers.get('content-type') if content_type: - headers["Content-Type"] = content_type + headers['Content-Type'] = content_type body = await request.body() session = aiohttp.ClientSession( @@ -140,7 +133,7 @@ async def proxy_terminal( data=body or None, ) - upstream_content_type = upstream_response.headers.get("content-type", "") + upstream_content_type = upstream_response.headers.get('content-type', '') filtered_headers = { key: value for key, value in upstream_response.headers.items() @@ -167,16 +160,12 @@ async def proxy_terminal( await upstream_response.release() await session.close() - return Response( - content=response_body, status_code=status_code, headers=filtered_headers - ) + return Response(content=response_body, status_code=status_code, headers=filtered_headers) except Exception as error: await session.close() - log.exception("Terminal proxy error: %s", error) - return JSONResponse( - {"error": f"Terminal proxy error: {error}"}, status_code=502 - ) + log.exception('Terminal proxy error: %s', error) + return JSONResponse({'error': f'Terminal proxy error: {error}'}, status_code=502) # --------------------------------------------------------------------------- @@ -201,42 +190,42 @@ async def _resolve_authenticated_connection(ws: WebSocket, server_id: str): try: raw = await asyncio.wait_for(ws.receive_text(), timeout=10.0) payload = json.loads(raw) - if payload.get("type") != "auth": - await ws.close(code=4001, reason="Expected auth message") + if payload.get('type') != 'auth': + await ws.close(code=4001, reason='Expected auth message') return None - token = payload.get("token", "") + token = payload.get('token', '') data = decode_token(token) - if data is None or "id" not in data: - await ws.close(code=4001, reason="Invalid token") + if data is None or 'id' not in data: + await ws.close(code=4001, reason='Invalid token') return None - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data['id']) if user is None: - await ws.close(code=4001, reason="User not found") + await ws.close(code=4001, reason='User not found') return None except (asyncio.TimeoutError, json.JSONDecodeError): - await ws.close(code=4001, reason="Auth timeout or invalid payload") + await ws.close(code=4001, reason='Auth timeout or invalid payload') return None except Exception: - await ws.close(code=4001, reason="Invalid token") + await ws.close(code=4001, reason='Invalid token') return None # Resolve terminal server connections = ws.app.state.config.TERMINAL_SERVER_CONNECTIONS or [] - connection = next((c for c in connections if c.get("id") == server_id), None) + connection = next((c for c in connections if c.get('id') == server_id), None) if connection is None: - await ws.close(code=4004, reason="Terminal server not found") + await ws.close(code=4004, reason='Terminal server not found') return None user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not has_connection_access(user, connection, user_group_ids): - await ws.close(code=4003, reason="Access denied") + await ws.close(code=4003, reason='Access denied') return None return user, connection -@router.websocket("/{server_id}/api/terminals/{session_id}") +@router.websocket('/{server_id}/api/terminals/{session_id}') async def ws_terminal( ws: WebSocket, server_id: str, @@ -255,28 +244,28 @@ async def ws_terminal( return user, connection = result - base_url = (connection.get("url") or "").rstrip("/") + base_url = (connection.get('url') or '').rstrip('/') if not base_url: - await ws.close(code=4003, reason="Terminal server URL not configured") + await ws.close(code=4003, reason='Terminal server URL not configured') return # Build upstream WebSocket URL (no token in URL) - ws_base = base_url.replace("https://", "wss://").replace("http://", "ws://") + ws_base = base_url.replace('https://', 'wss://').replace('http://', 'ws://') # Route through orchestrator policy endpoint if policy_id is set - policy_id = connection.get("policy_id") + policy_id = connection.get('policy_id') upstream_params = {} # For orchestrator-backed servers, pass user_id - upstream_params["user_id"] = user.id + upstream_params['user_id'] = user.id import urllib.parse if policy_id: - upstream_url = f"{ws_base}/p/{policy_id}/api/terminals/{session_id}" + upstream_url = f'{ws_base}/p/{policy_id}/api/terminals/{session_id}' else: - upstream_url = f"{ws_base}/api/terminals/{session_id}" + upstream_url = f'{ws_base}/api/terminals/{session_id}' if upstream_params: - upstream_url += f"?{urllib.parse.urlencode(upstream_params)}" + upstream_url += f'?{urllib.parse.urlencode(upstream_params)}' session = aiohttp.ClientSession() try: @@ -285,22 +274,22 @@ async def ws_terminal( import json as _json # First-message auth to upstream terminal server - auth_type = connection.get("auth_type", "bearer") - if auth_type == "bearer": - key = connection.get("key", "") - await upstream.send_str(_json.dumps({"type": "auth", "token": key})) + auth_type = connection.get('auth_type', 'bearer') + if auth_type == 'bearer': + key = connection.get('key', '') + await upstream.send_str(_json.dumps({'type': 'auth', 'token': key})) async def _client_to_upstream(): """Forward client → upstream.""" try: while True: msg = await ws.receive() - if msg["type"] == "websocket.disconnect": + if msg['type'] == 'websocket.disconnect': break - elif "bytes" in msg and msg["bytes"]: - await upstream.send_bytes(msg["bytes"]) - elif "text" in msg and msg["text"]: - await upstream.send_str(msg["text"]) + elif 'bytes' in msg and msg['bytes']: + await upstream.send_bytes(msg['bytes']) + elif 'text' in msg and msg['text']: + await upstream.send_str(msg['text']) except Exception: pass @@ -326,7 +315,7 @@ async def ws_terminal( return_exceptions=True, ) except Exception as e: - log.exception("Terminal WebSocket proxy error: %s", e) + log.exception('Terminal WebSocket proxy error: %s', e) finally: await session.close() try: diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 1fcc5b14bb..dcf416a606 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -59,7 +59,7 @@ def get_tool_module(request, tool_id, load_from_db=True): ############################ -@router.get("/", response_model=list[ToolUserResponse]) +@router.get('/', response_model=list[ToolUserResponse]) async def get_tools( request: Request, user=Depends(get_verified_user), @@ -69,18 +69,12 @@ async def get_tools( # Local Tools for tool in Tools.get_tools(defer_content=True, db=db): - tool_module = ( - request.app.state.TOOLS.get(tool.id) - if hasattr(request.app.state, "TOOLS") - else None - ) + tool_module = request.app.state.TOOLS.get(tool.id) if hasattr(request.app.state, 'TOOLS') else None tools.append( ToolUserResponse( **{ **tool.model_dump(), - "has_user_valves": ( - hasattr(tool_module, "UserValves") if tool_module else False - ), + 'has_user_valves': (hasattr(tool_module, 'UserValves') if tool_module else False), } ) ) @@ -88,94 +82,82 @@ async def get_tools( # OpenAPI Tool Servers server_access_grants = {} for server in await get_tool_servers(request): - server_idx = server.get("idx", 0) + server_idx = server.get('idx', 0) connections = request.app.state.config.TOOL_SERVER_CONNECTIONS if server_idx >= len(connections): log.warning( - f"Tool server index {server_idx} out of range " - f"(have {len(connections)} connections), skipping server {server.get('id')}" + f'Tool server index {server_idx} out of range ' + f'(have {len(connections)} connections), skipping server {server.get("id")}' ) continue connection = connections[server_idx] - server_config = connection.get("config", {}) + server_config = connection.get('config', {}) - server_id = f"server:{server.get('id')}" - server_access_grants[server_id] = server_config.get("access_grants", []) + server_id = f'server:{server.get("id")}' + server_access_grants[server_id] = server_config.get('access_grants', []) tools.append( ToolUserResponse( **{ - "id": server_id, - "user_id": server_id, - "name": server.get("openapi", {}) - .get("info", {}) - .get("title", "Tool Server"), - "meta": { - "description": server.get("openapi", {}) - .get("info", {}) - .get("description", ""), + 'id': server_id, + 'user_id': server_id, + 'name': server.get('openapi', {}).get('info', {}).get('title', 'Tool Server'), + 'meta': { + 'description': server.get('openapi', {}).get('info', {}).get('description', ''), }, - "updated_at": int(time.time()), - "created_at": int(time.time()), + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) ) # MCP Tool Servers for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: - if server.get("type", "openapi") == "mcp" and server.get("config", {}).get( - "enable" - ): - server_id = server.get("info", {}).get("id") - auth_type = server.get("auth_type", "none") + if server.get('type', 'openapi') == 'mcp' and server.get('config', {}).get('enable'): + server_id = server.get('info', {}).get('id') + auth_type = server.get('auth_type', 'none') session_token = None - if auth_type == "oauth_2.1": - splits = server_id.split(":") + if auth_type == 'oauth_2.1': + splits = server_id.split(':') server_id = splits[-1] if len(splits) > 1 else server_id - session_token = ( - await request.app.state.oauth_client_manager.get_oauth_token( - user.id, f"mcp:{server_id}" - ) + session_token = await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f'mcp:{server_id}' ) - server_config = server.get("config", {}) + server_config = server.get('config', {}) - tool_id = f"server:mcp:{server.get('info', {}).get('id')}" - server_access_grants[tool_id] = server_config.get("access_grants", []) + tool_id = f'server:mcp:{server.get("info", {}).get("id")}' + server_access_grants[tool_id] = server_config.get('access_grants', []) tools.append( ToolUserResponse( **{ - "id": tool_id, - "user_id": tool_id, - "name": server.get("info", {}).get("name", "MCP Tool Server"), - "meta": { - "description": server.get("info", {}).get( - "description", "" - ), + 'id': tool_id, + 'user_id': tool_id, + 'name': server.get('info', {}).get('name', 'MCP Tool Server'), + 'meta': { + 'description': server.get('info', {}).get('description', ''), }, - "updated_at": int(time.time()), - "created_at": int(time.time()), + 'updated_at': int(time.time()), + 'created_at': int(time.time()), **( { - "authenticated": session_token is not None, + 'authenticated': session_token is not None, } - if auth_type == "oauth_2.1" + if auth_type == 'oauth_2.1' else {} ), } ) ) - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: # Admin can see all tools return tools else: - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} tools = [ tool for tool in tools @@ -183,17 +165,17 @@ async def get_tools( or ( has_access( user.id, - "read", + 'read', server_access_grants.get(str(tool.id), []), user_group_ids, db=db, ) - if str(tool.id).startswith("server:") + if str(tool.id).startswith('server:') else AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tool.id, - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) @@ -207,34 +189,25 @@ async def get_tools( ############################ -@router.get("/list", response_model=list[ToolAccessResponse]) -async def get_tool_list( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: +@router.get('/list', response_model=list[ToolAccessResponse]) +async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)): + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: tools = Tools.get_tools(defer_content=True, db=db) else: - tools = Tools.get_tools_by_user_id(user.id, "read", defer_content=True, db=db) + tools = Tools.get_tools_by_user_id(user.id, 'read', defer_content=True, db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} result = [] for tool in tools: has_write = ( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tool.user_id or any( - g.permission == "write" + g.permission == 'write' and ( - ( - g.principal_type == "user" - and (g.principal_id == user.id or g.principal_id == "*") - ) - or ( - g.principal_type == "group" and g.principal_id in user_group_ids - ) + (g.principal_type == 'user' and (g.principal_id == user.id or g.principal_id == '*')) + or (g.principal_type == 'group' and g.principal_id in user_group_ids) ) for g in tool.access_grants ) @@ -259,70 +232,59 @@ class LoadUrlForm(BaseModel): def github_url_to_raw_url(url: str) -> str: # Handle 'tree' (folder) URLs (add main.py at the end) - m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url) + m1 = re.match(r'https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)', url) if m1: org, repo, branch, path = m1.groups() - return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py" + return f'https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip("/")}/main.py' # Handle 'blob' (file) URLs - m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url) + m2 = re.match(r'https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)', url) if m2: org, repo, branch, path = m2.groups() - return ( - f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}" - ) + return f'https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}' # No match; return as-is return url -@router.post("/load/url", response_model=Optional[dict]) -async def load_tool_from_url( - request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user) -): +@router.post('/load/url', response_model=Optional[dict]) +async def load_tool_from_url(request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)): # NOTE: This is NOT a SSRF vulnerability: # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, # and does NOT accept untrusted user input. Access is enforced by authentication. url = str(form_data.url) if not url: - raise HTTPException(status_code=400, detail="Please enter a valid URL") + raise HTTPException(status_code=400, detail='Please enter a valid URL') url = github_url_to_raw_url(url) - url_parts = url.rstrip("/").split("/") + url_parts = url.rstrip('/').split('/') file_name = url_parts[-1] tool_name = ( file_name[:-3] - if ( - file_name.endswith(".py") - and (not file_name.startswith(("main.py", "index.py", "__init__.py"))) - ) - else url_parts[-2] if len(url_parts) > 1 else "function" + if (file_name.endswith('.py') and (not file_name.startswith(('main.py', 'index.py', '__init__.py')))) + else url_parts[-2] + if len(url_parts) > 1 + else 'function' ) try: async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: - async with session.get( - url, headers={"Content-Type": "application/json"} - ) as resp: + async with session.get(url, headers={'Content-Type': 'application/json'}) as resp: if resp.status != 200: - raise HTTPException( - status_code=resp.status, detail="Failed to fetch the tool" - ) + raise HTTPException(status_code=resp.status, detail='Failed to fetch the tool') data = await resp.text() if not data: - raise HTTPException( - status_code=400, detail="No data received from the URL" - ) + raise HTTPException(status_code=400, detail='No data received from the URL') return { - "name": tool_name, - "content": data, + 'name': tool_name, + 'content': data, } except Exception as e: - raise HTTPException(status_code=500, detail=f"Error importing tool: {e}") + raise HTTPException(status_code=500, detail=f'Error importing tool: {e}') ############################ @@ -330,15 +292,15 @@ async def load_tool_from_url( ############################ -@router.get("/export", response_model=list[ToolModel]) +@router.get('/export', response_model=list[ToolModel]) async def export_tools( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not has_permission( + if user.role != 'admin' and not has_permission( user.id, - "workspace.tools_export", + 'workspace.tools_export', request.app.state.config.USER_PERMISSIONS, db=db, ): @@ -347,10 +309,10 @@ async def export_tools( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: return Tools.get_tools(db=db) else: - return Tools.get_tools_by_user_id(user.id, "read", db=db) + return Tools.get_tools_by_user_id(user.id, 'read', db=db) ############################ @@ -358,20 +320,18 @@ async def export_tools( ############################ -@router.post("/create", response_model=Optional[ToolResponse]) +@router.post('/create', response_model=Optional[ToolResponse]) async def create_new_tools( request: Request, form_data: ToolForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - if user.role != "admin" and not ( - has_permission( - user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db - ) + if user.role != 'admin' and not ( + has_permission(user.id, 'workspace.tools', request.app.state.config.USER_PERMISSIONS, db=db) or has_permission( user.id, - "workspace.tools_import", + 'workspace.tools_import', request.app.state.config.USER_PERMISSIONS, db=db, ) @@ -384,7 +344,7 @@ async def create_new_tools( if not form_data.id.isidentifier(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Only alphanumeric characters and underscores are allowed in the id", + detail='Only alphanumeric characters and underscores are allowed in the id', ) form_data.id = form_data.id.lower() @@ -393,9 +353,7 @@ async def create_new_tools( if tools is None: try: form_data.content = replace_imports(form_data.content) - tool_module, frontmatter = load_tool_module_by_id( - form_data.id, content=form_data.content - ) + tool_module, frontmatter = load_tool_module_by_id(form_data.id, content=form_data.content) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS @@ -404,7 +362,7 @@ async def create_new_tools( specs = get_tool_specs(TOOLS[form_data.id]) tools = Tools.insert_new_tool(user.id, form_data, specs, db=db) - tool_cache_dir = CACHE_DIR / "tools" / form_data.id + tool_cache_dir = CACHE_DIR / 'tools' / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) if tools: @@ -412,10 +370,10 @@ async def create_new_tools( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), + detail=ERROR_MESSAGES.DEFAULT('Error creating tools'), ) except Exception as e: - log.exception(f"Failed to load the tool by id {form_data.id}: {e}") + log.exception(f'Failed to load the tool by id {form_data.id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -432,34 +390,32 @@ async def create_new_tools( ############################ -@router.get("/id/{id}", response_model=Optional[ToolAccessResponse]) -async def get_tools_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}', response_model=Optional[ToolAccessResponse]) +async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): tools = Tools.get_tool_by_id(id, db=db) if tools: if ( - user.role == "admin" + user.role == 'admin' or tools.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="read", + permission='read', db=db, ) ): return ToolAccessResponse( **tools.model_dump(), write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tools.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) ), @@ -481,7 +437,7 @@ async def get_tools_by_id( ############################ -@router.post("/id/{id}/update", response_model=Optional[ToolModel]) +@router.post('/id/{id}/update', response_model=Optional[ToolModel]) async def update_tools_by_id( request: Request, id: str, @@ -501,12 +457,12 @@ async def update_tools_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -524,8 +480,8 @@ async def update_tools_by_id( specs = get_tool_specs(TOOLS[id]) updated = { - **form_data.model_dump(exclude={"id"}), - "specs": specs, + **form_data.model_dump(exclude={'id'}), + 'specs': specs, } log.debug(updated) @@ -536,7 +492,7 @@ async def update_tools_by_id( else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating tools"), + detail=ERROR_MESSAGES.DEFAULT('Error updating tools'), ) except Exception as e: @@ -555,7 +511,7 @@ class ToolAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post("/id/{id}/access/update", response_model=Optional[ToolModel]) +@router.post('/id/{id}/access/update', response_model=Optional[ToolModel]) async def update_tool_access_by_id( request: Request, id: str, @@ -574,12 +530,12 @@ async def update_tool_access_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -591,10 +547,10 @@ async def update_tool_access_by_id( user.id, user.role, form_data.access_grants, - "sharing.public_tools", + 'sharing.public_tools', ) - AccessGrants.set_access_grants("tool", id, form_data.access_grants, db=db) + AccessGrants.set_access_grants('tool', id, form_data.access_grants, db=db) return Tools.get_tool_by_id(id, db=db) @@ -604,7 +560,7 @@ async def update_tool_access_by_id( ############################ -@router.delete("/id/{id}/delete", response_model=bool) +@router.delete('/id/{id}/delete', response_model=bool) async def delete_tools_by_id( request: Request, id: str, @@ -622,12 +578,12 @@ async def delete_tools_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -648,10 +604,8 @@ async def delete_tools_by_id( ############################ -@router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_tools_valves_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/valves', response_model=Optional[dict]) +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( @@ -663,12 +617,12 @@ async def get_tools_valves_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -690,7 +644,7 @@ async def get_tools_valves_by_id( ############################ -@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) +@router.get('/id/{id}/valves/spec', response_model=Optional[dict]) async def get_tools_valves_spec_by_id( request: Request, id: str, @@ -708,12 +662,12 @@ async def get_tools_valves_spec_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -726,7 +680,7 @@ async def get_tools_valves_spec_by_id( tools_module, _ = load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module - if hasattr(tools_module, "Valves"): + if hasattr(tools_module, 'Valves'): Valves = tools_module.Valves schema = Valves.schema() # Resolve dynamic options for select dropdowns @@ -740,7 +694,7 @@ async def get_tools_valves_spec_by_id( ############################ -@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +@router.post('/id/{id}/valves/update', response_model=Optional[dict]) async def update_tools_valves_by_id( request: Request, id: str, @@ -759,12 +713,12 @@ async def update_tools_valves_by_id( tools.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tools.id, - permission="write", + permission='write', db=db, ) - and user.role != "admin" + and user.role != 'admin' ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -777,7 +731,7 @@ async def update_tools_valves_by_id( tools_module, _ = load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module - if not hasattr(tools_module, "Valves"): + if not hasattr(tools_module, 'Valves'): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, @@ -791,7 +745,7 @@ async def update_tools_valves_by_id( Tools.update_tool_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: - log.exception(f"Failed to update tool valves by id {id}: {e}") + log.exception(f'Failed to update tool valves by id {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -803,10 +757,8 @@ async def update_tools_valves_by_id( ############################ -@router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_tools_user_valves_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/id/{id}/valves/user', response_model=Optional[dict]) +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): tools = Tools.get_tool_by_id(id, db=db) if tools: try: @@ -824,7 +776,7 @@ async def get_tools_user_valves_by_id( ) -@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) +@router.get('/id/{id}/valves/user/spec', response_model=Optional[dict]) async def get_tools_user_valves_spec_by_id( request: Request, id: str, @@ -839,7 +791,7 @@ async def get_tools_user_valves_spec_by_id( tools_module, _ = load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module - if hasattr(tools_module, "UserValves"): + if hasattr(tools_module, 'UserValves'): UserValves = tools_module.UserValves schema = UserValves.schema() # Resolve dynamic options for select dropdowns @@ -853,7 +805,7 @@ async def get_tools_user_valves_spec_by_id( ) -@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) +@router.post('/id/{id}/valves/user/update', response_model=Optional[dict]) async def update_tools_user_valves_by_id( request: Request, id: str, @@ -870,19 +822,17 @@ async def update_tools_user_valves_by_id( tools_module, _ = load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module - if hasattr(tools_module, "UserValves"): + if hasattr(tools_module, 'UserValves'): UserValves = tools_module.UserValves try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) - Tools.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict, db=db - ) + Tools.update_user_valves_by_id_and_user_id(id, user.id, user_valves_dict, db=db) return user_valves_dict except Exception as e: - log.exception(f"Failed to update user valves by id {id}: {e}") + log.exception(f'Failed to update user valves by id {id}: {e}') raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 143a374d9c..0bc28a2b74 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -54,7 +54,7 @@ router = APIRouter() PAGE_ITEM_COUNT = 30 -@router.get("/", response_model=UserGroupIdsListResponse) +@router.get('/', response_model=UserGroupIdsListResponse) async def get_users( query: Optional[str] = None, order_by: Optional[str] = None, @@ -70,38 +70,38 @@ async def get_users( filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction - filter["direction"] = direction + filter['direction'] = direction result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) - users = result["users"] - total = result["total"] + users = result['users'] + total = result['total'] # Fetch groups for all users in a single query to avoid N+1 user_ids = [user.id for user in users] user_groups = Groups.get_groups_by_member_ids(user_ids, db=db) return { - "users": [ + 'users': [ UserGroupIdsModel( **{ **user.model_dump(), - "group_ids": [group.id for group in user_groups.get(user.id, [])], + 'group_ids': [group.id for group in user_groups.get(user.id, [])], } ) for user in users ], - "total": total, + 'total': total, } -@router.get("/all", response_model=UserInfoListResponse) +@router.get('/all', response_model=UserInfoListResponse) async def get_all_users( user=Depends(get_admin_user), db: Session = Depends(get_session), @@ -109,7 +109,7 @@ async def get_all_users( return Users.get_users(db=db) -@router.get("/search", response_model=UserInfoListResponse) +@router.get('/search', response_model=UserInfoListResponse) async def search_users( query: Optional[str] = None, order_by: Optional[str] = None, @@ -125,11 +125,11 @@ async def search_users( filter = {} if query: - filter["query"] = query + filter['query'] = query if order_by: - filter["order_by"] = order_by + filter['order_by'] = order_by if direction: - filter["direction"] = direction + filter['direction'] = direction return Users.get_users(filter=filter, skip=skip, limit=limit, db=db) @@ -139,10 +139,8 @@ async def search_users( ############################ -@router.get("/groups") -async def get_user_groups( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/groups') +async def get_user_groups(user=Depends(get_verified_user), db: Session = Depends(get_session)): return Groups.get_groups_by_member_id(user.id, db=db) @@ -151,15 +149,13 @@ async def get_user_groups( ############################ -@router.get("/permissions") +@router.get('/permissions') async def get_user_permissisions( request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db - ) + user_permissions = get_permissions(user.id, request.app.state.config.USER_PERMISSIONS, db=db) return user_permissions @@ -249,34 +245,20 @@ class UserPermissions(BaseModel): settings: SettingsPermissions -@router.get("/default/permissions", response_model=UserPermissions) +@router.get('/default/permissions', response_model=UserPermissions) async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)): return { - "workspace": WorkspacePermissions( - **request.app.state.config.USER_PERMISSIONS.get("workspace", {}) - ), - "sharing": SharingPermissions( - **request.app.state.config.USER_PERMISSIONS.get("sharing", {}) - ), - "access_grants": AccessGrantsPermissions( - **request.app.state.config.USER_PERMISSIONS.get("access_grants", {}) - ), - "chat": ChatPermissions( - **request.app.state.config.USER_PERMISSIONS.get("chat", {}) - ), - "features": FeaturesPermissions( - **request.app.state.config.USER_PERMISSIONS.get("features", {}) - ), - "settings": SettingsPermissions( - **request.app.state.config.USER_PERMISSIONS.get("settings", {}) - ), + 'workspace': WorkspacePermissions(**request.app.state.config.USER_PERMISSIONS.get('workspace', {})), + 'sharing': SharingPermissions(**request.app.state.config.USER_PERMISSIONS.get('sharing', {})), + 'access_grants': AccessGrantsPermissions(**request.app.state.config.USER_PERMISSIONS.get('access_grants', {})), + 'chat': ChatPermissions(**request.app.state.config.USER_PERMISSIONS.get('chat', {})), + 'features': FeaturesPermissions(**request.app.state.config.USER_PERMISSIONS.get('features', {})), + 'settings': SettingsPermissions(**request.app.state.config.USER_PERMISSIONS.get('settings', {})), } -@router.post("/default/permissions") -async def update_default_user_permissions( - request: Request, form_data: UserPermissions, user=Depends(get_admin_user) -): +@router.post('/default/permissions') +async def update_default_user_permissions(request: Request, form_data: UserPermissions, user=Depends(get_admin_user)): request.app.state.config.USER_PERMISSIONS = form_data.model_dump() return request.app.state.config.USER_PERMISSIONS @@ -286,10 +268,8 @@ async def update_default_user_permissions( ############################ -@router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/user/settings', response_model=Optional[UserSettings]) +async def get_user_settings_by_session_user(user=Depends(get_verified_user), db: Session = Depends(get_session)): user = Users.get_user_by_id(user.id, db=db) if user: return user.settings @@ -305,7 +285,7 @@ async def get_user_settings_by_session_user( ############################ -@router.post("/user/settings/update", response_model=UserSettings) +@router.post('/user/settings/update', response_model=UserSettings) async def update_user_settings_by_session_user( request: Request, form_data: UserSettings, @@ -313,19 +293,19 @@ async def update_user_settings_by_session_user( db: Session = Depends(get_session), ): updated_user_settings = form_data.model_dump() - ui_settings = updated_user_settings.get("ui") + ui_settings = updated_user_settings.get('ui') if ( - user.role != "admin" + user.role != 'admin' and ui_settings is not None - and "toolServers" in ui_settings.keys() + and 'toolServers' in ui_settings.keys() and not has_permission( user.id, - "features.direct_tool_servers", + 'features.direct_tool_servers', request.app.state.config.USER_PERMISSIONS, ) ): # If the user is not an admin and does not have permission to use tool servers, remove the key - updated_user_settings["ui"].pop("toolServers", None) + updated_user_settings['ui'].pop('toolServers', None) user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db) if user: @@ -342,7 +322,7 @@ async def update_user_settings_by_session_user( ############################ -@router.get("/user/status") +@router.get('/user/status') async def get_user_status_by_session_user( request: Request, user=Depends(get_verified_user), @@ -368,7 +348,7 @@ async def get_user_status_by_session_user( ############################ -@router.post("/user/status/update") +@router.post('/user/status/update') async def update_user_status_by_session_user( request: Request, form_data: UserStatus, @@ -396,10 +376,8 @@ async def update_user_status_by_session_user( ############################ -@router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/user/info', response_model=Optional[dict]) +async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Session = Depends(get_session)): user = Users.get_user_by_id(user.id, db=db) if user: return user.info @@ -415,7 +393,7 @@ async def get_user_info_by_session_user( ############################ -@router.post("/user/info/update", response_model=Optional[dict]) +@router.post('/user/info/update', response_model=Optional[dict]) async def update_user_info_by_session_user( form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): @@ -424,9 +402,7 @@ async def update_user_info_by_session_user( if user.info is None: user.info = {} - user = Users.update_user_by_id( - user.id, {"info": {**user.info, **form_data}}, db=db - ) + user = Users.update_user_by_id(user.id, {'info': {**user.info, **form_data}}, db=db) if user: return user.info else: @@ -452,17 +428,15 @@ class UserActiveResponse(UserStatus): groups: Optional[list] = [] is_active: bool - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') -@router.get("/{user_id}", response_model=UserActiveResponse) -async def get_user_by_id( - user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/{user_id}', response_model=UserActiveResponse) +async def get_user_by_id(user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): # Check if user_id is a shared chat # If it is, get the user_id from the chat - if user_id.startswith("shared-"): - chat_id = user_id.replace("shared-", "") + if user_id.startswith('shared-'): + chat_id = user_id.replace('shared-', '') chat = Chats.get_chat_by_id(chat_id) if chat: user_id = chat.user_id @@ -478,8 +452,8 @@ async def get_user_by_id( return UserActiveResponse( **{ **user.model_dump(), - "groups": [{"id": group.id, "name": group.name} for group in groups], - "is_active": Users.is_user_active(user_id, db=db), + 'groups': [{'id': group.id, 'name': group.name} for group in groups], + 'is_active': Users.is_user_active(user_id, db=db), } ) else: @@ -489,18 +463,16 @@ async def get_user_by_id( ) -@router.get("/{user_id}/info", response_model=UserInfoResponse) -async def get_user_info_by_id( - user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): +@router.get('/{user_id}/info', response_model=UserInfoResponse) +async def get_user_info_by_id(user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): user = Users.get_user_by_id(user_id, db=db) if user: groups = Groups.get_groups_by_member_id(user_id, db=db) return UserInfoResponse( **{ **user.model_dump(), - "groups": [{"id": group.id, "name": group.name} for group in groups], - "is_active": Users.is_user_active(user_id, db=db), + 'groups': [{'id': group.id, 'name': group.name} for group in groups], + 'is_active': Users.is_user_active(user_id, db=db), } ) else: @@ -510,10 +482,8 @@ async def get_user_info_by_id( ) -@router.get("/{user_id}/oauth/sessions") -async def get_user_oauth_sessions_by_id( - user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/{user_id}/oauth/sessions') +async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db) if sessions and len(sessions) > 0: return sessions @@ -529,32 +499,32 @@ async def get_user_oauth_sessions_by_id( ############################ -@router.get("/{user_id}/profile/image") +@router.get('/{user_id}/profile/image') def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): user = Users.get_user_by_id(user_id) if user: if user.profile_image_url: # check if it's url or base64 - if user.profile_image_url.startswith("http"): + if user.profile_image_url.startswith('http'): return Response( status_code=status.HTTP_302_FOUND, - headers={"Location": user.profile_image_url}, + headers={'Location': user.profile_image_url}, ) - elif user.profile_image_url.startswith("data:image"): + elif user.profile_image_url.startswith('data:image'): try: - header, base64_data = user.profile_image_url.split(",", 1) + header, base64_data = user.profile_image_url.split(',', 1) image_data = base64.b64decode(base64_data) image_buffer = io.BytesIO(image_data) - media_type = header.split(";")[0].lstrip("data:") + media_type = header.split(';')[0].lstrip('data:') return StreamingResponse( image_buffer, media_type=media_type, - headers={"Content-Disposition": "inline"}, + headers={'Content-Disposition': 'inline'}, ) except Exception as e: pass - return FileResponse(f"{STATIC_DIR}/user.png") + return FileResponse(f'{STATIC_DIR}/user.png') else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -567,12 +537,12 @@ def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): ############################ -@router.get("/{user_id}/active", response_model=dict) +@router.get('/{user_id}/active', response_model=dict) async def get_user_active_status_by_id( user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): return { - "active": Users.is_user_active(user_id, db=db), + 'active': Users.is_user_active(user_id, db=db), } @@ -581,7 +551,7 @@ async def get_user_active_status_by_id( ############################ -@router.post("/{user_id}/update", response_model=Optional[UserModel]) +@router.post('/{user_id}/update', response_model=Optional[UserModel]) async def update_user_by_id( user_id: str, form_data: UserUpdateForm, @@ -600,7 +570,7 @@ async def update_user_by_id( detail=ERROR_MESSAGES.ACTION_PROHIBITED, ) - if form_data.role != "admin": + if form_data.role != 'admin': # If the primary admin is trying to change their own role, prevent it raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -608,10 +578,10 @@ async def update_user_by_id( ) except Exception as e: - log.error(f"Error checking primary admin status: {e}") + log.error(f'Error checking primary admin status: {e}') raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Could not verify primary admin status.", + detail='Could not verify primary admin status.', ) user = Users.get_user_by_id(user_id, db=db) @@ -638,10 +608,10 @@ async def update_user_by_id( updated_user = Users.update_user_by_id( user_id, { - "role": form_data.role, - "name": form_data.name, - "email": form_data.email.lower(), - "profile_image_url": form_data.profile_image_url, + 'role': form_data.role, + 'name': form_data.name, + 'email': form_data.email.lower(), + 'profile_image_url': form_data.profile_image_url, }, db=db, ) @@ -665,10 +635,8 @@ async def update_user_by_id( ############################ -@router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id( - user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.delete('/{user_id}', response_model=bool) +async def delete_user_by_id(user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): # Prevent deletion of the primary admin user try: first_user = Users.get_first_user(db=db) @@ -678,10 +646,10 @@ async def delete_user_by_id( detail=ERROR_MESSAGES.ACTION_PROHIBITED, ) except Exception as e: - log.error(f"Error checking primary admin status: {e}") + log.error(f'Error checking primary admin status: {e}') raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Could not verify primary admin status.", + detail='Could not verify primary admin status.', ) if user.id != user_id: @@ -707,8 +675,6 @@ async def delete_user_by_id( ############################ -@router.get("/{user_id}/groups") -async def get_user_groups_by_id( - user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +@router.get('/{user_id}/groups') +async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): return Groups.get_groups_by_member_id(user_id, db=db) diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 49f3a5ca55..7ea4150021 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -20,7 +20,7 @@ log = logging.getLogger(__name__) router = APIRouter() -@router.get("/gravatar") +@router.get('/gravatar') async def get_gravatar(email: str, user=Depends(get_verified_user)): return get_gravatar_url(email) @@ -29,33 +29,31 @@ class CodeForm(BaseModel): code: str -@router.post("/code/format") +@router.post('/code/format') async def format_code(form_data: CodeForm, user=Depends(get_admin_user)): try: formatted_code = black.format_str(form_data.code, mode=black.Mode()) - return {"code": formatted_code} + return {'code': formatted_code} except black.NothingChanged: - return {"code": form_data.code} + return {'code': form_data.code} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) -@router.post("/code/execute") -async def execute_code( - request: Request, form_data: CodeForm, user=Depends(get_verified_user) -): - if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter": +@router.post('/code/execute') +async def execute_code(request: Request, form_data: CodeForm, user=Depends(get_verified_user)): + if request.app.state.config.CODE_EXECUTION_ENGINE == 'jupyter': output = await execute_code_jupyter( request.app.state.config.CODE_EXECUTION_JUPYTER_URL, form_data.code, ( request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN - if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token" + if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == 'token' else None ), ( request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD - if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password" + if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == 'password' else None ), request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, @@ -65,7 +63,7 @@ async def execute_code( else: raise HTTPException( status_code=400, - detail="Code execution engine not supported", + detail='Code execution engine not supported', ) @@ -73,11 +71,9 @@ class MarkdownForm(BaseModel): md: str -@router.post("/markdown") -async def get_html_from_markdown( - form_data: MarkdownForm, user=Depends(get_verified_user) -): - return {"html": markdown.markdown(form_data.md)} +@router.post('/markdown') +async def get_html_from_markdown(form_data: MarkdownForm, user=Depends(get_verified_user)): + return {'html': markdown.markdown(form_data.md)} class ChatForm(BaseModel): @@ -85,24 +81,22 @@ class ChatForm(BaseModel): messages: list[dict] -@router.post("/pdf") -async def download_chat_as_pdf( - form_data: ChatTitleMessagesForm, user=Depends(get_verified_user) -): +@router.post('/pdf') +async def download_chat_as_pdf(form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)): try: pdf_bytes = PDFGenerator(form_data).generate_chat_pdf() return Response( content=pdf_bytes, - media_type="application/pdf", - headers={"Content-Disposition": "attachment;filename=chat.pdf"}, + media_type='application/pdf', + headers={'Content-Disposition': 'attachment;filename=chat.pdf'}, ) except Exception as e: - log.exception(f"Error generating PDF: {e}") + log.exception(f'Error generating PDF: {e}') raise HTTPException(status_code=400, detail=str(e)) -@router.get("/db/download") +@router.get('/db/download') async def download_db(user=Depends(get_admin_user)): if not ENABLE_ADMIN_EXPORT: raise HTTPException( @@ -111,13 +105,13 @@ async def download_db(user=Depends(get_admin_user)): ) from open_webui.internal.db import engine - if engine.name != "sqlite": + if engine.name != 'sqlite': raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DB_NOT_SQLITE, ) return FileResponse( engine.url.database, - media_type="application/octet-stream", - filename="webui.db", + media_type='application/octet-stream', + filename='webui.db', ) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 6cf8217500..1518193da8 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -58,24 +58,20 @@ log = logging.getLogger(__name__) REDIS = None # Configure CORS for Socket.IO -SOCKETIO_CORS_ORIGINS = "*" if CORS_ALLOW_ORIGIN == ["*"] else CORS_ALLOW_ORIGIN +SOCKETIO_CORS_ORIGINS = '*' if CORS_ALLOW_ORIGIN == ['*'] else CORS_ALLOW_ORIGIN -if WEBSOCKET_MANAGER == "redis": +if WEBSOCKET_MANAGER == 'redis': if WEBSOCKET_SENTINEL_HOSTS: mgr = socketio.AsyncRedisManager( - get_sentinel_url_from_env( - WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT - ), + get_sentinel_url_from_env(WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT), redis_options=WEBSOCKET_REDIS_OPTIONS, ) else: - mgr = socketio.AsyncRedisManager( - WEBSOCKET_REDIS_URL, redis_options=WEBSOCKET_REDIS_OPTIONS - ) + mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL, redis_options=WEBSOCKET_REDIS_OPTIONS) sio = socketio.AsyncServer( cors_allowed_origins=SOCKETIO_CORS_ORIGINS, - async_mode="asgi", - transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), + async_mode='asgi', + transports=(['websocket'] if ENABLE_WEBSOCKET_SUPPORT else ['polling']), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, client_manager=mgr, @@ -87,8 +83,8 @@ if WEBSOCKET_MANAGER == "redis": else: sio = socketio.AsyncServer( cors_allowed_origins=SOCKETIO_CORS_ORIGINS, - async_mode="asgi", - transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), + async_mode='asgi', + transports=(['websocket'] if ENABLE_WEBSOCKET_SUPPORT else ['polling']), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, logger=WEBSOCKET_SERVER_LOGGING, @@ -104,36 +100,32 @@ SESSION_POOL_TIMEOUT = 120 # seconds without heartbeat before session is reaped # Dictionary to maintain the user pool -if WEBSOCKET_MANAGER == "redis": - log.debug("Using Redis to manage websockets.") +if WEBSOCKET_MANAGER == 'redis': + log.debug('Using Redis to manage websockets.') REDIS = get_redis_connection( redis_url=WEBSOCKET_REDIS_URL, - redis_sentinels=get_sentinels_from_env( - WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT - ), + redis_sentinels=get_sentinels_from_env(WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT), redis_cluster=WEBSOCKET_REDIS_CLUSTER, async_mode=True, ) - redis_sentinels = get_sentinels_from_env( - WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT - ) + redis_sentinels = get_sentinels_from_env(WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT) MODELS = RedisDict( - f"{REDIS_KEY_PREFIX}:models", + f'{REDIS_KEY_PREFIX}:models', redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) SESSION_POOL = RedisDict( - f"{REDIS_KEY_PREFIX}:session_pool", + f'{REDIS_KEY_PREFIX}:session_pool', redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) USAGE_POOL = RedisDict( - f"{REDIS_KEY_PREFIX}:usage_pool", + f'{REDIS_KEY_PREFIX}:usage_pool', redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, @@ -141,7 +133,7 @@ if WEBSOCKET_MANAGER == "redis": clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, - lock_name=f"{REDIS_KEY_PREFIX}:usage_cleanup_lock", + lock_name=f'{REDIS_KEY_PREFIX}:usage_cleanup_lock', timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, @@ -152,7 +144,7 @@ if WEBSOCKET_MANAGER == "redis": session_cleanup_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, - lock_name=f"{REDIS_KEY_PREFIX}:session_cleanup_lock", + lock_name=f'{REDIS_KEY_PREFIX}:session_cleanup_lock', timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, @@ -172,29 +164,27 @@ else: YDOC_MANAGER = YdocManager( redis=REDIS, - redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents", + redis_key_prefix=f'{REDIS_KEY_PREFIX}:ydoc:documents', ) async def periodic_session_pool_cleanup(): """Reap orphaned SESSION_POOL entries that missed heartbeats (e.g. crashed instance).""" if not session_aquire_func(): - log.debug("Session cleanup lock held by another node. Skipping.") + log.debug('Session cleanup lock held by another node. Skipping.') return try: while True: if not session_renew_func(): - log.error("Unable to renew session cleanup lock. Exiting.") + log.error('Unable to renew session cleanup lock. Exiting.') return now = int(time.time()) for sid in list(SESSION_POOL.keys()): entry = SESSION_POOL.get(sid) - if entry and now - entry.get("last_seen_at", 0) > SESSION_POOL_TIMEOUT: - log.warning( - f"Reaping orphaned session {sid} (user {entry.get('id')})" - ) + if entry and now - entry.get('last_seen_at', 0) > SESSION_POOL_TIMEOUT: + log.warning(f'Reaping orphaned session {sid} (user {entry.get("id")})') del SESSION_POOL[sid] await asyncio.sleep(SESSION_POOL_TIMEOUT) finally: @@ -203,46 +193,38 @@ async def periodic_session_pool_cleanup(): async def periodic_usage_pool_cleanup(): max_retries = 2 - retry_delay = random.uniform( - WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT - ) + retry_delay = random.uniform(WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT) for attempt in range(max_retries + 1): if aquire_func(): break else: if attempt < max_retries: - log.debug( - f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..." - ) + log.debug(f'Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s...') await asyncio.sleep(retry_delay) else: - log.warning( - "Failed to acquire cleanup lock after retries. Skipping cleanup." - ) + log.warning('Failed to acquire cleanup lock after retries. Skipping cleanup.') return - log.debug("Running periodic_cleanup") + log.debug('Running periodic_cleanup') try: while True: if not renew_func(): - log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.") - raise Exception("Unable to renew usage pool cleanup lock.") + log.error(f'Unable to renew cleanup lock. Exiting usage pool cleanup.') + raise Exception('Unable to renew usage pool cleanup lock.') now = int(time.time()) send_usage = False for model_id, connections in list(USAGE_POOL.items()): # Creating a list of sids to remove if they have timed out expired_sids = [ - sid - for sid, details in connections.items() - if now - details["updated_at"] > TIMEOUT_DURATION + sid for sid, details in connections.items() if now - details['updated_at'] > TIMEOUT_DURATION ] for sid in expired_sids: del connections[sid] if not connections: - log.debug(f"Cleaning up model {model_id} from usage pool") + log.debug(f'Cleaning up model {model_id} from usage pool') del USAGE_POOL[model_id] else: USAGE_POOL[model_id] = connections @@ -255,7 +237,7 @@ async def periodic_usage_pool_cleanup(): app = socketio.ASGIApp( sio, - socketio_path="/ws/socket.io", + socketio_path='/ws/socket.io', ) @@ -268,14 +250,14 @@ def get_models_in_use(): def get_user_id_from_session_pool(sid): user = SESSION_POOL.get(sid) if user: - return user["id"] + return user['id'] return None def get_session_ids_from_room(room): """Get all session IDs from a specific room.""" active_session_ids = sio.manager.get_participants( - namespace="/", + namespace='/', room=room, ) return [session_id[0] for session_id in active_session_ids] @@ -287,7 +269,7 @@ def get_user_ids_from_room(room): active_user_ids = list( set( [ - SESSION_POOL.get(session_id)["id"] + SESSION_POOL.get(session_id)['id'] for session_id in active_session_ids if SESSION_POOL.get(session_id) is not None ] @@ -307,9 +289,9 @@ async def emit_to_users(event: str, data: dict, user_ids: list[str]): """ try: for user_id in user_ids: - await sio.emit(event, data, room=f"user:{user_id}") + await sio.emit(event, data, room=f'user:{user_id}') except Exception as e: - log.debug(f"Failed to emit event {event} to users {user_ids}: {e}") + log.debug(f'Failed to emit event {event} to users {user_ids}: {e}') async def enter_room_for_users(room: str, user_ids: list[str]): @@ -321,163 +303,162 @@ async def enter_room_for_users(room: str, user_ids: list[str]): """ try: for user_id in user_ids: - session_ids = get_session_ids_from_room(f"user:{user_id}") + session_ids = get_session_ids_from_room(f'user:{user_id}') for sid in session_ids: await sio.enter_room(sid, room) except Exception as e: - log.debug(f"Failed to make users {user_ids} join room {room}: {e}") + log.debug(f'Failed to make users {user_ids} join room {room}: {e}') -@sio.on("usage") +@sio.on('usage') async def usage(sid, data): if sid in SESSION_POOL: - model_id = data["model"] + model_id = data['model'] # Record the timestamp for the last update current_time = int(time.time()) # Store the new usage data and task USAGE_POOL[model_id] = { **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}), - sid: {"updated_at": current_time}, + sid: {'updated_at': current_time}, } @sio.event async def connect(sid, environ, auth): user = None - if auth and "token" in auth: - data = decode_token(auth["token"]) + if auth and 'token' in auth: + data = decode_token(auth['token']) - if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + if data is not None and 'id' in data: + user = Users.get_user_by_id(data['id']) if user: SESSION_POOL[sid] = { **user.model_dump( exclude=[ - "profile_image_url", - "profile_banner_image_url", - "date_of_birth", - "bio", - "gender", + 'profile_image_url', + 'profile_banner_image_url', + 'date_of_birth', + 'bio', + 'gender', ] ), - "last_seen_at": int(time.time()), + 'last_seen_at': int(time.time()), } - await sio.enter_room(sid, f"user:{user.id}") + await sio.enter_room(sid, f'user:{user.id}') -@sio.on("user-join") +@sio.on('user-join') async def user_join(sid, data): - - auth = data["auth"] if "auth" in data else None - if not auth or "token" not in auth: + auth = data['auth'] if 'auth' in data else None + if not auth or 'token' not in auth: return - data = decode_token(auth["token"]) - if data is None or "id" not in data: + data = decode_token(auth['token']) + if data is None or 'id' not in data: return - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data['id']) if not user: return SESSION_POOL[sid] = { **user.model_dump( exclude=[ - "profile_image_url", - "profile_banner_image_url", - "date_of_birth", - "bio", - "gender", + 'profile_image_url', + 'profile_banner_image_url', + 'date_of_birth', + 'bio', + 'gender', ] ), - "last_seen_at": int(time.time()), + 'last_seen_at': int(time.time()), } - await sio.enter_room(sid, f"user:{user.id}") + await sio.enter_room(sid, f'user:{user.id}') # Join all the channels only if user has channels permission - if user.role == "admin" or has_permission(user.id, "features.channels"): + if user.role == 'admin' or has_permission(user.id, 'features.channels'): channels = Channels.get_channels_by_user_id(user.id) - log.debug(f"{channels=}") + log.debug(f'{channels=}') for channel in channels: - await sio.enter_room(sid, f"channel:{channel.id}") + await sio.enter_room(sid, f'channel:{channel.id}') - return {"id": user.id, "name": user.name} + return {'id': user.id, 'name': user.name} -@sio.on("heartbeat") +@sio.on('heartbeat') async def heartbeat(sid, data): user = SESSION_POOL.get(sid) if user: - SESSION_POOL[sid] = {**user, "last_seen_at": int(time.time())} - Users.update_last_active_by_id(user["id"]) + SESSION_POOL[sid] = {**user, 'last_seen_at': int(time.time())} + Users.update_last_active_by_id(user['id']) -@sio.on("join-channels") +@sio.on('join-channels') async def join_channel(sid, data): - auth = data["auth"] if "auth" in data else None - if not auth or "token" not in auth: + auth = data['auth'] if 'auth' in data else None + if not auth or 'token' not in auth: return - data = decode_token(auth["token"]) - if data is None or "id" not in data: + data = decode_token(auth['token']) + if data is None or 'id' not in data: return - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data['id']) if not user: return # Join all the channels only if user has channels permission - if user.role == "admin" or has_permission(user.id, "features.channels"): + if user.role == 'admin' or has_permission(user.id, 'features.channels'): channels = Channels.get_channels_by_user_id(user.id) - log.debug(f"{channels=}") + log.debug(f'{channels=}') for channel in channels: - await sio.enter_room(sid, f"channel:{channel.id}") + await sio.enter_room(sid, f'channel:{channel.id}') -@sio.on("join-note") +@sio.on('join-note') async def join_note(sid, data): - auth = data["auth"] if "auth" in data else None - if not auth or "token" not in auth: + auth = data['auth'] if 'auth' in data else None + if not auth or 'token' not in auth: return - token_data = decode_token(auth["token"]) - if token_data is None or "id" not in token_data: + token_data = decode_token(auth['token']) + if token_data is None or 'id' not in token_data: return - user = Users.get_user_by_id(token_data["id"]) + user = Users.get_user_by_id(token_data['id']) if not user: return - note = Notes.get_note_by_id(data["note_id"]) + note = Notes.get_note_by_id(data['note_id']) if not note: - log.error(f"Note {data['note_id']} not found for user {user.id}") + log.error(f'Note {data["note_id"]} not found for user {user.id}') return if ( - user.role != "admin" + user.role != 'admin' and user.id != note.user_id and not AccessGrants.has_access( user_id=user.id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="read", + permission='read', ) ): - log.error(f"User {user.id} does not have access to note {data['note_id']}") + log.error(f'User {user.id} does not have access to note {data["note_id"]}') return - log.debug(f"Joining note {note.id} for user {user.id}") - await sio.enter_room(sid, f"note:{note.id}") + log.debug(f'Joining note {note.id} for user {user.id}') + await sio.enter_room(sid, f'note:{note.id}') -@sio.on("events:channel") +@sio.on('events:channel') async def channel_events(sid, data): - room = f"channel:{data['channel_id']}" + room = f'channel:{data["channel_id"]}' participants = sio.manager.get_participants( - namespace="/", + namespace='/', room=room, ) @@ -485,27 +466,27 @@ async def channel_events(sid, data): if sid not in sids: return - event_data = data["data"] - event_type = event_data["type"] + event_data = data['data'] + event_type = event_data['type'] user = SESSION_POOL.get(sid) if not user: return - if event_type == "typing": + if event_type == 'typing': await sio.emit( - "events:channel", + 'events:channel', { - "channel_id": data["channel_id"], - "message_id": data.get("message_id", None), - "data": event_data, - "user": UserNameResponse(**user).model_dump(), + 'channel_id': data['channel_id'], + 'message_id': data.get('message_id', None), + 'data': event_data, + 'user': UserNameResponse(**user).model_dump(), }, room=room, ) - elif event_type == "last_read_at": - Channels.update_member_last_read_at(data["channel_id"], user["id"]) + elif event_type == 'last_read_at': + Channels.update_member_last_read_at(data['channel_id'], user['id']) def normalize_document_id(document_id: str) -> str: @@ -516,12 +497,12 @@ def normalize_document_id(document_id: str) -> str: We must rewrite underscore-prefixed IDs back to the colon form so that authorization checks (which key on "note:") always fire. """ - if document_id.startswith("note_"): - document_id = "note:" + document_id[5:] + if document_id.startswith('note_'): + document_id = 'note:' + document_id[5:] return document_id -@sio.on("ydoc:document:join") +@sio.on('ydoc:document:join') async def ydoc_document_join(sid, data): """Handle user joining a document""" user = SESSION_POOL.get(sid) @@ -529,41 +510,39 @@ async def ydoc_document_join(sid, data): return try: - document_id = normalize_document_id(data["document_id"]) + document_id = normalize_document_id(data['document_id']) - if document_id.startswith("note:"): - note_id = document_id.split(":")[1] + if document_id.startswith('note:'): + note_id = document_id.split(':')[1] note = Notes.get_note_by_id(note_id) if not note: - log.error(f"Note {note_id} not found") + log.error(f'Note {note_id} not found') return if ( - user.get("role") != "admin" - and user.get("id") != note.user_id + user.get('role') != 'admin' + and user.get('id') != note.user_id and not AccessGrants.has_access( - user_id=user.get("id"), - resource_type="note", + user_id=user.get('id'), + resource_type='note', resource_id=note.id, - permission="read", + permission='read', ) ): - log.error( - f"User {user.get('id')} does not have access to note {note_id}" - ) + log.error(f'User {user.get("id")} does not have access to note {note_id}') return - user_id = data.get("user_id", sid) - user_name = data.get("user_name", "Anonymous") - user_color = data.get("user_color", "#000000") + user_id = data.get('user_id', sid) + user_name = data.get('user_name', 'Anonymous') + user_color = data.get('user_color', '#000000') - log.info(f"User {user_id} joining document {document_id}") + log.info(f'User {user_id} joining document {document_id}') await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid) # Join Socket.IO room - await sio.enter_room(sid, f"doc_{document_id}") + await sio.enter_room(sid, f'doc_{document_id}') - active_session_ids = get_session_ids_from_room(f"doc_{document_id}") + active_session_ids = get_session_ids_from_room(f'doc_{document_id}') # Get the Yjs document state ydoc = Y.Doc() @@ -574,78 +553,78 @@ async def ydoc_document_join(sid, data): # Encode the entire document state as an update state_update = ydoc.get_update() await sio.emit( - "ydoc:document:state", + 'ydoc:document:state', { - "document_id": document_id, - "state": list(state_update), # Convert bytes to list for JSON - "sessions": active_session_ids, + 'document_id': document_id, + 'state': list(state_update), # Convert bytes to list for JSON + 'sessions': active_session_ids, }, room=sid, ) # Notify other users about the new user await sio.emit( - "ydoc:user:joined", + 'ydoc:user:joined', { - "document_id": document_id, - "user_id": user_id, - "user_name": user_name, - "user_color": user_color, + 'document_id': document_id, + 'user_id': user_id, + 'user_name': user_name, + 'user_color': user_color, }, - room=f"doc_{document_id}", + room=f'doc_{document_id}', skip_sid=sid, ) - log.info(f"User {user_id} successfully joined document {document_id}") + log.info(f'User {user_id} successfully joined document {document_id}') except Exception as e: - log.error(f"Error in yjs_document_join: {e}") - await sio.emit("error", {"message": "Failed to join document"}, room=sid) + log.error(f'Error in yjs_document_join: {e}') + await sio.emit('error', {'message': 'Failed to join document'}, room=sid) async def document_save_handler(document_id, data, user): document_id = normalize_document_id(document_id) - if document_id.startswith("note:"): - note_id = document_id.split(":")[1] + if document_id.startswith('note:'): + note_id = document_id.split(':')[1] note = Notes.get_note_by_id(note_id) if not note: - log.error(f"Note {note_id} not found") + log.error(f'Note {note_id} not found') return if ( - user.get("role") != "admin" - and user.get("id") != note.user_id + user.get('role') != 'admin' + and user.get('id') != note.user_id and not AccessGrants.has_access( - user_id=user.get("id"), - resource_type="note", + user_id=user.get('id'), + resource_type='note', resource_id=note.id, - permission="read", + permission='read', ) ): - log.error(f"User {user.get('id')} does not have access to note {note_id}") + log.error(f'User {user.get("id")} does not have access to note {note_id}') return Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) -@sio.on("ydoc:document:state") +@sio.on('ydoc:document:state') async def yjs_document_state(sid, data): """Send the current state of the Yjs document to the user""" try: - document_id = data["document_id"] + document_id = data['document_id'] document_id = normalize_document_id(document_id) - room = f"doc_{document_id}" + room = f'doc_{document_id}' active_session_ids = get_session_ids_from_room(room) if sid not in active_session_ids: - log.warning(f"Session {sid} not in room {room}. Cannot send state.") + log.warning(f'Session {sid} not in room {room}. Cannot send state.') return if not await YDOC_MANAGER.document_exists(document_id): - log.warning(f"Document {document_id} not found") + log.warning(f'Document {document_id} not found') return # Get the Yjs document state @@ -658,31 +637,31 @@ async def yjs_document_state(sid, data): state_update = ydoc.get_update() await sio.emit( - "ydoc:document:state", + 'ydoc:document:state', { - "document_id": document_id, - "state": list(state_update), # Convert bytes to list for JSON - "sessions": active_session_ids, + 'document_id': document_id, + 'state': list(state_update), # Convert bytes to list for JSON + 'sessions': active_session_ids, }, room=sid, ) except Exception as e: - log.error(f"Error in yjs_document_state: {e}") + log.error(f'Error in yjs_document_state: {e}') -@sio.on("ydoc:document:update") +@sio.on('ydoc:document:update') async def yjs_document_update(sid, data): """Handle Yjs document updates""" try: - document_id = data["document_id"] + document_id = data['document_id'] document_id = normalize_document_id(document_id) # Verify the sender actually joined this document room - room = f"doc_{document_id}" + room = f'doc_{document_id}' active_session_ids = get_session_ids_from_room(room) if sid not in active_session_ids: - log.warning(f"Session {sid} not in room {room}. Rejecting update.") + log.warning(f'Session {sid} not in room {room}. Rejecting update.') return try: @@ -690,9 +669,9 @@ async def yjs_document_update(sid, data): except Exception: pass - user_id = data.get("user_id", sid) + user_id = data.get('user_id', sid) - update = data["update"] # List of bytes from frontend + update = data['update'] # List of bytes from frontend await YDOC_MANAGER.append_to_updates( document_id=document_id, @@ -701,14 +680,14 @@ async def yjs_document_update(sid, data): # Broadcast update to all other users in the document await sio.emit( - "ydoc:document:update", + 'ydoc:document:update', { - "document_id": document_id, - "user_id": user_id, - "update": update, - "socket_id": sid, # Add socket_id to match frontend filtering + 'document_id': document_id, + 'user_id': user_id, + 'update': update, + 'socket_id': sid, # Add socket_id to match frontend filtering }, - room=f"doc_{document_id}", + room=f'doc_{document_id}', skip_sid=sid, ) @@ -718,66 +697,63 @@ async def yjs_document_update(sid, data): async def debounced_save(): await asyncio.sleep(0.5) - await document_save_handler(document_id, data.get("data", {}), user) + await document_save_handler(document_id, data.get('data', {}), user) - if data.get("data"): + if data.get('data'): await create_task(REDIS, debounced_save(), document_id) except Exception as e: - log.error(f"Error in yjs_document_update: {e}") + log.error(f'Error in yjs_document_update: {e}') -@sio.on("ydoc:document:leave") +@sio.on('ydoc:document:leave') async def yjs_document_leave(sid, data): """Handle user leaving a document""" try: - document_id = data["document_id"] - user_id = data.get("user_id", sid) + document_id = data['document_id'] + user_id = data.get('user_id', sid) - log.info(f"User {user_id} leaving document {document_id}") + log.info(f'User {user_id} leaving document {document_id}') # Remove user from the document await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid) # Leave Socket.IO room - await sio.leave_room(sid, f"doc_{document_id}") + await sio.leave_room(sid, f'doc_{document_id}') # Notify other users await sio.emit( - "ydoc:user:left", - {"document_id": document_id, "user_id": user_id}, - room=f"doc_{document_id}", + 'ydoc:user:left', + {'document_id': document_id, 'user_id': user_id}, + room=f'doc_{document_id}', ) - if ( - await YDOC_MANAGER.document_exists(document_id) - and len(await YDOC_MANAGER.get_users(document_id)) == 0 - ): - log.info(f"Cleaning up document {document_id} as no users are left") + if await YDOC_MANAGER.document_exists(document_id) and len(await YDOC_MANAGER.get_users(document_id)) == 0: + log.info(f'Cleaning up document {document_id} as no users are left') await YDOC_MANAGER.clear_document(document_id) except Exception as e: - log.error(f"Error in yjs_document_leave: {e}") + log.error(f'Error in yjs_document_leave: {e}') -@sio.on("ydoc:awareness:update") +@sio.on('ydoc:awareness:update') async def yjs_awareness_update(sid, data): """Handle awareness updates (cursors, selections, etc.)""" try: - document_id = data["document_id"] - user_id = data.get("user_id", sid) - update = data["update"] + document_id = data['document_id'] + user_id = data.get('user_id', sid) + update = data['update'] # Broadcast awareness update to all other users in the document await sio.emit( - "ydoc:awareness:update", - {"document_id": document_id, "user_id": user_id, "update": update}, - room=f"doc_{document_id}", + 'ydoc:awareness:update', + {'document_id': document_id, 'user_id': user_id, 'update': update}, + room=f'doc_{document_id}', skip_sid=sid, ) except Exception as e: - log.error(f"Error in yjs_awareness_update: {e}") + log.error(f'Error in yjs_awareness_update: {e}') @sio.event @@ -804,132 +780,123 @@ async def disconnect(sid): def get_event_emitter(request_info, update_db=True): async def __event_emitter__(event_data): - user_id = request_info["user_id"] - chat_id = request_info["chat_id"] - message_id = request_info["message_id"] + user_id = request_info['user_id'] + chat_id = request_info['chat_id'] + message_id = request_info['message_id'] await sio.emit( - "events", + 'events', { - "chat_id": chat_id, - "message_id": message_id, - "data": event_data, + 'chat_id': chat_id, + 'message_id': message_id, + 'data': event_data, }, - room=f"user:{user_id}", + room=f'user:{user_id}', ) - if ( - update_db - and message_id - and not request_info.get("chat_id", "").startswith("local:") - ): + if update_db and message_id and not request_info.get('chat_id', '').startswith('local:'): + event_type = event_data.get('type') - event_type = event_data.get("type") - - if event_type == "status": + if event_type == 'status': await asyncio.to_thread( Chats.add_message_status_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], - event_data.get("data", {}), + request_info['chat_id'], + request_info['message_id'], + event_data.get('data', {}), ) - elif event_type == "message": + elif event_type == 'message': message = await asyncio.to_thread( Chats.get_message_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], ) if message: - content = message.get("content", "") - content += event_data.get("data", {}).get("content", "") + content = message.get('content', '') + content += event_data.get('data', {}).get('content', '') await asyncio.to_thread( Chats.upsert_message_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], { - "content": content, + 'content': content, }, ) - elif event_type == "replace": - content = event_data.get("data", {}).get("content", "") + elif event_type == 'replace': + content = event_data.get('data', {}).get('content', '') await asyncio.to_thread( Chats.upsert_message_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], { - "content": content, + 'content': content, }, ) - elif event_type == "embeds": + elif event_type == 'embeds': message = await asyncio.to_thread( Chats.get_message_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], ) - embeds = event_data.get("data", {}).get("embeds", []) - embeds.extend(message.get("embeds", [])) + embeds = event_data.get('data', {}).get('embeds', []) + embeds.extend(message.get('embeds', [])) await asyncio.to_thread( Chats.upsert_message_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], { - "embeds": embeds, + 'embeds': embeds, }, ) - elif event_type == "files": + elif event_type == 'files': message = await asyncio.to_thread( Chats.get_message_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], ) - files = event_data.get("data", {}).get("files", []) - files.extend(message.get("files", [])) + files = event_data.get('data', {}).get('files', []) + files.extend(message.get('files', [])) await asyncio.to_thread( Chats.upsert_message_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], { - "files": files, + 'files': files, }, ) - elif event_type in ("source", "citation"): - data = event_data.get("data", {}) - if data.get("type") is None: + elif event_type in ('source', 'citation'): + data = event_data.get('data', {}) + if data.get('type') is None: message = await asyncio.to_thread( Chats.get_message_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], ) - sources = message.get("sources", []) + sources = message.get('sources', []) sources.append(data) await asyncio.to_thread( Chats.upsert_message_to_chat_by_id_and_message_id, - request_info["chat_id"], - request_info["message_id"], + request_info['chat_id'], + request_info['message_id'], { - "sources": sources, + 'sources': sources, }, ) - if ( - "user_id" in request_info - and "chat_id" in request_info - and "message_id" in request_info - ): + if 'user_id' in request_info and 'chat_id' in request_info and 'message_id' in request_info: return __event_emitter__ else: return None @@ -938,22 +905,18 @@ def get_event_emitter(request_info, update_db=True): def get_event_call(request_info): async def __event_caller__(event_data): response = await sio.call( - "events", + 'events', { - "chat_id": request_info.get("chat_id", None), - "message_id": request_info.get("message_id", None), - "data": event_data, + 'chat_id': request_info.get('chat_id', None), + 'message_id': request_info.get('message_id', None), + 'data': event_data, }, - to=request_info["session_id"], + to=request_info['session_id'], timeout=WEBSOCKET_EVENT_CALLER_TIMEOUT, ) return response - if ( - "session_id" in request_info - and "chat_id" in request_info - and "message_id" in request_info - ): + if 'session_id' in request_info and 'chat_id' in request_info and 'message_id' in request_info: return __event_caller__ else: return None diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index c33af2e71d..682d779ccd 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -15,7 +15,6 @@ class RedisLock: redis_sentinels=[], redis_cluster=False, ): - self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs @@ -29,16 +28,12 @@ class RedisLock: def aquire_lock(self): # nx=True will only set this key if it _hasn't_ already been set - self.lock_obtained = self.redis.set( - self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs - ) + self.lock_obtained = self.redis.set(self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs) return self.lock_obtained def renew_lock(self): # xx=True will only set this key if it _has_ already been set - return self.redis.set( - self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs - ) + return self.redis.set(self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs) def release_lock(self): lock_value = self.redis.get(self.lock_name) @@ -106,7 +101,7 @@ class RedisDict: def update(self, other=None, **kwargs): if other is not None: - for k, v in other.items() if hasattr(other, "items") else other: + for k, v in other.items() if hasattr(other, 'items') else other: self[k] = v for k, v in kwargs.items(): self[k] = v @@ -123,7 +118,7 @@ class YdocManager: def __init__( self, redis=None, - redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents", + redis_key_prefix: str = f'{REDIS_KEY_PREFIX}:ydoc:documents', ): self._updates = {} self._users = {} @@ -131,9 +126,9 @@ class YdocManager: self._redis_key_prefix = redis_key_prefix async def append_to_updates(self, document_id: str, update: bytes): - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + redis_key = f'{self._redis_key_prefix}:{document_id}:updates' await self._redis.rpush(redis_key, json.dumps(list(update))) list_len = await self._redis.llen(redis_key) if list_len >= self.COMPACTION_THRESHOLD: @@ -147,7 +142,7 @@ class YdocManager: async def _compact_updates_redis(self, document_id: str): """Rolling compaction: squash oldest half into one snapshot.""" - redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + redis_key = f'{self._redis_key_prefix}:{document_id}:updates' all_updates = await self._redis.lrange(redis_key, 0, -1) if len(all_updates) <= 1: return @@ -173,39 +168,39 @@ class YdocManager: self._updates[document_id] = [ydoc.get_update()] + updates[mid:] async def get_updates(self, document_id: str) -> List[bytes]: - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + redis_key = f'{self._redis_key_prefix}:{document_id}:updates' updates = await self._redis.lrange(redis_key, 0, -1) return [bytes(json.loads(update)) for update in updates] else: return self._updates.get(document_id, []) async def document_exists(self, document_id: str) -> bool: - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + redis_key = f'{self._redis_key_prefix}:{document_id}:updates' return await self._redis.exists(redis_key) > 0 else: return document_id in self._updates async def get_users(self, document_id: str) -> List[str]: - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:users" + redis_key = f'{self._redis_key_prefix}:{document_id}:users' users = await self._redis.smembers(redis_key) return list(users) else: return self._users.get(document_id, []) async def add_user(self, document_id: str, user_id: str): - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:users" + redis_key = f'{self._redis_key_prefix}:{document_id}:users' await self._redis.sadd(redis_key, user_id) else: if document_id not in self._users: @@ -213,10 +208,10 @@ class YdocManager: self._users[document_id].add(user_id) async def remove_user(self, document_id: str, user_id: str): - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:users" + redis_key = f'{self._redis_key_prefix}:{document_id}:users' await self._redis.srem(redis_key, user_id) else: if document_id in self._users and user_id in self._users[document_id]: @@ -225,15 +220,13 @@ class YdocManager: async def remove_user_from_all_documents(self, user_id: str): if self._redis: keys = [] - async for key in self._redis.scan_iter( - match=f"{self._redis_key_prefix}:*", count=100 - ): + async for key in self._redis.scan_iter(match=f'{self._redis_key_prefix}:*', count=100): keys.append(key) for key in keys: - if key.endswith(":users"): + if key.endswith(':users'): await self._redis.srem(key, user_id) - document_id = key.split(":")[-2] + document_id = key.split(':')[-2] if len(await self.get_users(document_id)) == 0: await self.clear_document(document_id) @@ -247,12 +240,12 @@ class YdocManager: await self.clear_document(document_id) async def clear_document(self, document_id: str): - document_id = document_id.replace(":", "_") + document_id = document_id.replace(':', '_') if self._redis: - redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + redis_key = f'{self._redis_key_prefix}:{document_id}:updates' await self._redis.delete(redis_key) - redis_users_key = f"{self._redis_key_prefix}:{document_id}:users" + redis_users_key = f'{self._redis_key_prefix}:{document_id}:users' await self._redis.delete(redis_users_key) else: if document_id in self._updates: diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index 425d10c812..3c29462349 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -43,9 +43,7 @@ class StorageProvider(ABC): pass @abstractmethod - def upload_file( - self, file: BinaryIO, filename: str, tags: Dict[str, str] - ) -> Tuple[bytes, str]: + def upload_file(self, file: BinaryIO, filename: str, tags: Dict[str, str]) -> Tuple[bytes, str]: pass @abstractmethod @@ -59,14 +57,12 @@ class StorageProvider(ABC): class LocalStorageProvider(StorageProvider): @staticmethod - def upload_file( - file: BinaryIO, filename: str, tags: Dict[str, str] - ) -> Tuple[bytes, str]: + def upload_file(file: BinaryIO, filename: str, tags: Dict[str, str]) -> Tuple[bytes, str]: contents = file.read() if not contents: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) - file_path = f"{UPLOAD_DIR}/{filename}" - with open(file_path, "wb") as f: + file_path = f'{UPLOAD_DIR}/{filename}' + with open(file_path, 'wb') as f: f.write(contents) return contents, file_path @@ -78,12 +74,12 @@ class LocalStorageProvider(StorageProvider): @staticmethod def delete_file(file_path: str) -> None: """Handles deletion of the file from local storage.""" - filename = file_path.split("/")[-1] - file_path = f"{UPLOAD_DIR}/{filename}" + filename = file_path.split('/')[-1] + file_path = f'{UPLOAD_DIR}/{filename}' if os.path.isfile(file_path): os.remove(file_path) else: - log.warning(f"File {file_path} not found in local storage.") + log.warning(f'File {file_path} not found in local storage.') @staticmethod def delete_all_files() -> None: @@ -97,27 +93,27 @@ class LocalStorageProvider(StorageProvider): elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - log.exception(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f'Failed to delete {file_path}. Reason: {e}') else: - log.warning(f"Directory {UPLOAD_DIR} not found in local storage.") + log.warning(f'Directory {UPLOAD_DIR} not found in local storage.') class S3StorageProvider(StorageProvider): def __init__(self): config = Config( s3={ - "use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT, - "addressing_style": S3_ADDRESSING_STYLE, + 'use_accelerate_endpoint': S3_USE_ACCELERATE_ENDPOINT, + 'addressing_style': S3_ADDRESSING_STYLE, }, # KIT change - see https://github.com/boto/boto3/issues/4400#issuecomment-2600742103∆ - request_checksum_calculation="when_required", - response_checksum_validation="when_required", + request_checksum_calculation='when_required', + response_checksum_validation='when_required', ) # If access key and secret are provided, use them for authentication if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY: self.s3_client = boto3.client( - "s3", + 's3', region_name=S3_REGION_NAME, endpoint_url=S3_ENDPOINT_URL, aws_access_key_id=S3_ACCESS_KEY_ID, @@ -128,49 +124,40 @@ class S3StorageProvider(StorageProvider): # If no explicit credentials are provided, fall back to default AWS credentials # This supports workload identity (IAM roles for EC2, EKS, etc.) self.s3_client = boto3.client( - "s3", + 's3', region_name=S3_REGION_NAME, endpoint_url=S3_ENDPOINT_URL, config=config, ) self.bucket_name = S3_BUCKET_NAME - self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" + self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else '' @staticmethod def sanitize_tag_value(s: str) -> str: """Only include S3 allowed characters.""" - return re.sub(r"[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]", "", s) + return re.sub(r'[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]', '', s) - def upload_file( - self, file: BinaryIO, filename: str, tags: Dict[str, str] - ) -> Tuple[bytes, str]: + def upload_file(self, file: BinaryIO, filename: str, tags: Dict[str, str]) -> Tuple[bytes, str]: """Handles uploading of the file to S3 storage.""" _, file_path = LocalStorageProvider.upload_file(file, filename, tags) s3_key = os.path.join(self.key_prefix, filename) try: self.s3_client.upload_file(file_path, self.bucket_name, s3_key) if S3_ENABLE_TAGGING and tags: - sanitized_tags = { - self.sanitize_tag_value(k): self.sanitize_tag_value(v) - for k, v in tags.items() - } - tagging = { - "TagSet": [ - {"Key": k, "Value": v} for k, v in sanitized_tags.items() - ] - } + sanitized_tags = {self.sanitize_tag_value(k): self.sanitize_tag_value(v) for k, v in tags.items()} + tagging = {'TagSet': [{'Key': k, 'Value': v} for k, v in sanitized_tags.items()]} self.s3_client.put_object_tagging( Bucket=self.bucket_name, Key=s3_key, Tagging=tagging, ) return ( - open(file_path, "rb").read(), - f"s3://{self.bucket_name}/{s3_key}", + open(file_path, 'rb').read(), + f's3://{self.bucket_name}/{s3_key}', ) except ClientError as e: - raise RuntimeError(f"Error uploading file to S3: {e}") + raise RuntimeError(f'Error uploading file to S3: {e}') def get_file(self, file_path: str) -> str: """Handles downloading of the file from S3 storage.""" @@ -180,7 +167,7 @@ class S3StorageProvider(StorageProvider): self.s3_client.download_file(self.bucket_name, s3_key, local_file_path) return local_file_path except ClientError as e: - raise RuntimeError(f"Error downloading file from S3: {e}") + raise RuntimeError(f'Error downloading file from S3: {e}') def delete_file(self, file_path: str) -> None: """Handles deletion of the file from S3 storage.""" @@ -188,7 +175,7 @@ class S3StorageProvider(StorageProvider): s3_key = self._extract_s3_key(file_path) self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key) except ClientError as e: - raise RuntimeError(f"Error deleting file from S3: {e}") + raise RuntimeError(f'Error deleting file from S3: {e}') # Always delete from local storage LocalStorageProvider.delete_file(file_path) @@ -197,27 +184,25 @@ class S3StorageProvider(StorageProvider): """Handles deletion of all files from S3 storage.""" try: response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) - if "Contents" in response: - for content in response["Contents"]: + if 'Contents' in response: + for content in response['Contents']: # Skip objects that were not uploaded from open-webui in the first place - if not content["Key"].startswith(self.key_prefix): + if not content['Key'].startswith(self.key_prefix): continue - self.s3_client.delete_object( - Bucket=self.bucket_name, Key=content["Key"] - ) + self.s3_client.delete_object(Bucket=self.bucket_name, Key=content['Key']) except ClientError as e: - raise RuntimeError(f"Error deleting all files from S3: {e}") + raise RuntimeError(f'Error deleting all files from S3: {e}') # Always delete from local storage LocalStorageProvider.delete_all_files() # The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name. def _extract_s3_key(self, full_file_path: str) -> str: - return "/".join(full_file_path.split("//")[1].split("/")[1:]) + return '/'.join(full_file_path.split('//')[1].split('/')[1:]) def _get_local_file_path(self, s3_key: str) -> str: - return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}" + return f'{UPLOAD_DIR}/{s3_key.split("/")[-1]}' class GCSStorageProvider(StorageProvider): @@ -235,38 +220,36 @@ class GCSStorageProvider(StorageProvider): self.gcs_client = storage.Client() self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME) - def upload_file( - self, file: BinaryIO, filename: str, tags: Dict[str, str] - ) -> Tuple[bytes, str]: + def upload_file(self, file: BinaryIO, filename: str, tags: Dict[str, str]) -> Tuple[bytes, str]: """Handles uploading of the file to GCS storage.""" contents, file_path = LocalStorageProvider.upload_file(file, filename, tags) try: blob = self.bucket.blob(filename) blob.upload_from_filename(file_path) - return contents, "gs://" + self.bucket_name + "/" + filename + return contents, 'gs://' + self.bucket_name + '/' + filename except GoogleCloudError as e: - raise RuntimeError(f"Error uploading file to GCS: {e}") + raise RuntimeError(f'Error uploading file to GCS: {e}') def get_file(self, file_path: str) -> str: """Handles downloading of the file from GCS storage.""" try: - filename = file_path.removeprefix("gs://").split("/")[1] - local_file_path = f"{UPLOAD_DIR}/{filename}" + filename = file_path.removeprefix('gs://').split('/')[1] + local_file_path = f'{UPLOAD_DIR}/{filename}' blob = self.bucket.get_blob(filename) blob.download_to_filename(local_file_path) return local_file_path except NotFound as e: - raise RuntimeError(f"Error downloading file from GCS: {e}") + raise RuntimeError(f'Error downloading file from GCS: {e}') def delete_file(self, file_path: str) -> None: """Handles deletion of the file from GCS storage.""" try: - filename = file_path.removeprefix("gs://").split("/")[1] + filename = file_path.removeprefix('gs://').split('/')[1] blob = self.bucket.get_blob(filename) blob.delete() except NotFound as e: - raise RuntimeError(f"Error deleting file from GCS: {e}") + raise RuntimeError(f'Error deleting file from GCS: {e}') # Always delete from local storage LocalStorageProvider.delete_file(file_path) @@ -280,7 +263,7 @@ class GCSStorageProvider(StorageProvider): blob.delete() except NotFound as e: - raise RuntimeError(f"Error deleting all files from GCS: {e}") + raise RuntimeError(f'Error deleting all files from GCS: {e}') # Always delete from local storage LocalStorageProvider.delete_all_files() @@ -294,51 +277,43 @@ class AzureStorageProvider(StorageProvider): if storage_key: # Configure using the Azure Storage Account Endpoint and Key - self.blob_service_client = BlobServiceClient( - account_url=self.endpoint, credential=storage_key - ) + self.blob_service_client = BlobServiceClient(account_url=self.endpoint, credential=storage_key) else: # Configure using the Azure Storage Account Endpoint and DefaultAzureCredential # If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication - self.blob_service_client = BlobServiceClient( - account_url=self.endpoint, credential=DefaultAzureCredential() - ) - self.container_client = self.blob_service_client.get_container_client( - self.container_name - ) + self.blob_service_client = BlobServiceClient(account_url=self.endpoint, credential=DefaultAzureCredential()) + self.container_client = self.blob_service_client.get_container_client(self.container_name) - def upload_file( - self, file: BinaryIO, filename: str, tags: Dict[str, str] - ) -> Tuple[bytes, str]: + def upload_file(self, file: BinaryIO, filename: str, tags: Dict[str, str]) -> Tuple[bytes, str]: """Handles uploading of the file to Azure Blob Storage.""" contents, file_path = LocalStorageProvider.upload_file(file, filename, tags) try: blob_client = self.container_client.get_blob_client(filename) blob_client.upload_blob(contents, overwrite=True) - return contents, f"{self.endpoint}/{self.container_name}/{filename}" + return contents, f'{self.endpoint}/{self.container_name}/{filename}' except Exception as e: - raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}") + raise RuntimeError(f'Error uploading file to Azure Blob Storage: {e}') def get_file(self, file_path: str) -> str: """Handles downloading of the file from Azure Blob Storage.""" try: - filename = file_path.split("/")[-1] - local_file_path = f"{UPLOAD_DIR}/{filename}" + filename = file_path.split('/')[-1] + local_file_path = f'{UPLOAD_DIR}/{filename}' blob_client = self.container_client.get_blob_client(filename) - with open(local_file_path, "wb") as download_file: + with open(local_file_path, 'wb') as download_file: download_file.write(blob_client.download_blob().readall()) return local_file_path except ResourceNotFoundError as e: - raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}") + raise RuntimeError(f'Error downloading file from Azure Blob Storage: {e}') def delete_file(self, file_path: str) -> None: """Handles deletion of the file from Azure Blob Storage.""" try: - filename = file_path.split("/")[-1] + filename = file_path.split('/')[-1] blob_client = self.container_client.get_blob_client(filename) blob_client.delete_blob() except ResourceNotFoundError as e: - raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}") + raise RuntimeError(f'Error deleting file from Azure Blob Storage: {e}') # Always delete from local storage LocalStorageProvider.delete_file(file_path) @@ -350,23 +325,23 @@ class AzureStorageProvider(StorageProvider): for blob in blobs: self.container_client.delete_blob(blob.name) except Exception as e: - raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}") + raise RuntimeError(f'Error deleting all files from Azure Blob Storage: {e}') # Always delete from local storage LocalStorageProvider.delete_all_files() def get_storage_provider(storage_provider: str): - if storage_provider == "local": + if storage_provider == 'local': Storage = LocalStorageProvider() - elif storage_provider == "s3": + elif storage_provider == 's3': Storage = S3StorageProvider() - elif storage_provider == "gcs": + elif storage_provider == 'gcs': Storage = GCSStorageProvider() - elif storage_provider == "azure": + elif storage_provider == 'azure': Storage = AzureStorageProvider() else: - raise RuntimeError(f"Unsupported storage provider: {storage_provider}") + raise RuntimeError(f'Unsupported storage provider: {storage_provider}') return Storage diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py index 002f26c457..30754cfc48 100644 --- a/backend/open_webui/tasks.py +++ b/backend/open_webui/tasks.py @@ -17,9 +17,9 @@ tasks: Dict[str, asyncio.Task] = {} item_tasks = {} -REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks" -REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item" -REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands" +REDIS_TASKS_KEY = f'{REDIS_KEY_PREFIX}:tasks' +REDIS_ITEM_TASKS_KEY = f'{REDIS_KEY_PREFIX}:tasks:item' +REDIS_PUBSUB_CHANNEL = f'{REDIS_KEY_PREFIX}:tasks:commands' async def redis_task_command_listener(app): @@ -28,17 +28,17 @@ async def redis_task_command_listener(app): await pubsub.subscribe(REDIS_PUBSUB_CHANNEL) async for message in pubsub.listen(): - if message["type"] != "message": + if message['type'] != 'message': continue try: - command = json.loads(message["data"]) - if command.get("action") == "stop": - task_id = command.get("task_id") + command = json.loads(message['data']) + if command.get('action') == 'stop': + task_id = command.get('task_id') local_task = tasks.get(task_id) if local_task: local_task.cancel() except Exception as e: - log.exception(f"Error handling distributed task command: {e}") + log.exception(f'Error handling distributed task command: {e}') ### ------------------------------ @@ -48,9 +48,9 @@ async def redis_task_command_listener(app): async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]): pipe = redis.pipeline() - pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "") + pipe.hset(REDIS_TASKS_KEY, task_id, item_id or '') if item_id: - pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id) + pipe.sadd(f'{REDIS_ITEM_TASKS_KEY}:{item_id}', task_id) await pipe.execute() @@ -58,11 +58,11 @@ async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]) pipe = redis.pipeline() pipe.hdel(REDIS_TASKS_KEY, task_id) if item_id: - pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id) + pipe.srem(f'{REDIS_ITEM_TASKS_KEY}:{item_id}', task_id) await pipe.execute() # Remove the set key entirely if no tasks remain for this item - if await redis.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") == 0: - await redis.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") + if await redis.scard(f'{REDIS_ITEM_TASKS_KEY}:{item_id}') == 0: + await redis.delete(f'{REDIS_ITEM_TASKS_KEY}:{item_id}') else: await pipe.execute() @@ -72,15 +72,15 @@ async def redis_list_tasks(redis: Redis) -> List[str]: async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]: - return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}")) + return list(await redis.smembers(f'{REDIS_ITEM_TASKS_KEY}:{item_id}')) async def redis_send_command(redis: Redis, command: dict): command_json = json.dumps(command) # RedisCluster doesn't expose publish() directly, but the # PUBLISH command broadcasts across all cluster nodes server-side. - if hasattr(redis, "nodes_manager"): - await redis.execute_command("PUBLISH", REDIS_PUBSUB_CHANNEL, command_json) + if hasattr(redis, 'nodes_manager'): + await redis.execute_command('PUBLISH', REDIS_PUBSUB_CHANNEL, command_json) else: await redis.publish(REDIS_PUBSUB_CHANNEL, command_json) @@ -109,9 +109,7 @@ async def create_task(redis, coroutine, id=None): task = asyncio.create_task(coroutine) # Create the task # Add a done callback for cleanup - task.add_done_callback( - lambda t: asyncio.create_task(cleanup_task(redis, task_id, id)) - ) + task.add_done_callback(lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))) tasks[task_id] = task # If an ID is provided, associate the task with that ID @@ -155,30 +153,30 @@ async def stop_task(redis, task_id: str): await redis_send_command( redis, { - "action": "stop", - "task_id": task_id, + 'action': 'stop', + 'task_id': task_id, }, ) # Always clean Redis directly — hdel/srem are idempotent, safe even # if the done_callback on the owning process also fires cleanup. await redis_cleanup_task(redis, task_id, item_id or None) - return {"status": True, "message": f"Task {task_id} stopped."} + return {'status': True, 'message': f'Task {task_id} stopped.'} task = tasks.pop(task_id, None) if not task: - return {"status": False, "message": f"Task with ID {task_id} not found."} + return {'status': False, 'message': f'Task with ID {task_id} not found.'} task.cancel() # Request task cancellation try: await task # Wait for the task to handle the cancellation except asyncio.CancelledError: # Task successfully canceled - return {"status": True, "message": f"Task {task_id} successfully stopped."} + return {'status': True, 'message': f'Task {task_id} successfully stopped.'} if task.cancelled() or task.done(): - return {"status": True, "message": f"Task {task_id} successfully cancelled."} + return {'status': True, 'message': f'Task {task_id} successfully cancelled.'} - return {"status": True, "message": f"Cancellation requested for {task_id}."} + return {'status': True, 'message': f'Cancellation requested for {task_id}.'} async def stop_item_tasks(redis: Redis, item_id: str): @@ -187,14 +185,14 @@ async def stop_item_tasks(redis: Redis, item_id: str): """ task_ids = await list_task_ids_by_item_id(redis, item_id) if not task_ids: - return {"status": True, "message": f"No tasks found for item {item_id}."} + return {'status': True, 'message': f'No tasks found for item {item_id}.'} for task_id in task_ids: result = await stop_task(redis, task_id) - if not result["status"]: + if not result['status']: return result # Return the first failure - return {"status": True, "message": f"All tasks for item {item_id} stopped."} + return {'status': True, 'message': f'All tasks for item {item_id} stopped.'} async def has_active_tasks(redis, chat_id: str) -> bool: diff --git a/backend/open_webui/test/apps/webui/routers/test_auths.py b/backend/open_webui/test/apps/webui/routers/test_auths.py index f0f69e26d2..9f9ae9bc5c 100644 --- a/backend/open_webui/test/apps/webui/routers/test_auths.py +++ b/backend/open_webui/test/apps/webui/routers/test_auths.py @@ -3,7 +3,7 @@ from test.util.mock_user import mock_webui_user class TestAuths(AbstractPostgresTest): - BASE_PATH = "/api/v1/auths" + BASE_PATH = '/api/v1/auths' def setup_class(cls): super().setup_class() @@ -15,171 +15,167 @@ class TestAuths(AbstractPostgresTest): def test_get_session_user(self): with mock_webui_user(): - response = self.fast_api_client.get(self.create_url("")) + response = self.fast_api_client.get(self.create_url('')) assert response.status_code == 200 assert response.json() == { - "id": "1", - "name": "John Doe", - "email": "john.doe@openwebui.com", - "role": "user", - "profile_image_url": "/user.png", + 'id': '1', + 'name': 'John Doe', + 'email': 'john.doe@openwebui.com', + 'role': 'user', + 'profile_image_url': '/user.png', } def test_update_profile(self): from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password=get_password_hash("old_password"), - name="John Doe", - profile_image_url="/user.png", - role="user", + email='john.doe@openwebui.com', + password=get_password_hash('old_password'), + name='John Doe', + profile_image_url='/user.png', + role='user', ) with mock_webui_user(id=user.id): response = self.fast_api_client.post( - self.create_url("/update/profile"), - json={"name": "John Doe 2", "profile_image_url": "/user2.png"}, + self.create_url('/update/profile'), + json={'name': 'John Doe 2', 'profile_image_url': '/user2.png'}, ) assert response.status_code == 200 db_user = self.users.get_user_by_id(user.id) - assert db_user.name == "John Doe 2" - assert db_user.profile_image_url == "/user2.png" + assert db_user.name == 'John Doe 2' + assert db_user.profile_image_url == '/user2.png' def test_update_password(self): from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password=get_password_hash("old_password"), - name="John Doe", - profile_image_url="/user.png", - role="user", + email='john.doe@openwebui.com', + password=get_password_hash('old_password'), + name='John Doe', + profile_image_url='/user.png', + role='user', ) with mock_webui_user(id=user.id): response = self.fast_api_client.post( - self.create_url("/update/password"), - json={"password": "old_password", "new_password": "new_password"}, + self.create_url('/update/password'), + json={'password': 'old_password', 'new_password': 'new_password'}, ) assert response.status_code == 200 - old_auth = self.auths.authenticate_user( - "john.doe@openwebui.com", "old_password" - ) + old_auth = self.auths.authenticate_user('john.doe@openwebui.com', 'old_password') assert old_auth is None - new_auth = self.auths.authenticate_user( - "john.doe@openwebui.com", "new_password" - ) + new_auth = self.auths.authenticate_user('john.doe@openwebui.com', 'new_password') assert new_auth is not None def test_signin(self): from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password=get_password_hash("password"), - name="John Doe", - profile_image_url="/user.png", - role="user", + email='john.doe@openwebui.com', + password=get_password_hash('password'), + name='John Doe', + profile_image_url='/user.png', + role='user', ) response = self.fast_api_client.post( - self.create_url("/signin"), - json={"email": "john.doe@openwebui.com", "password": "password"}, + self.create_url('/signin'), + json={'email': 'john.doe@openwebui.com', 'password': 'password'}, ) assert response.status_code == 200 data = response.json() - assert data["id"] == user.id - assert data["name"] == "John Doe" - assert data["email"] == "john.doe@openwebui.com" - assert data["role"] == "user" - assert data["profile_image_url"] == "/user.png" - assert data["token"] is not None and len(data["token"]) > 0 - assert data["token_type"] == "Bearer" + assert data['id'] == user.id + assert data['name'] == 'John Doe' + assert data['email'] == 'john.doe@openwebui.com' + assert data['role'] == 'user' + assert data['profile_image_url'] == '/user.png' + assert data['token'] is not None and len(data['token']) > 0 + assert data['token_type'] == 'Bearer' def test_signup(self): response = self.fast_api_client.post( - self.create_url("/signup"), + self.create_url('/signup'), json={ - "name": "John Doe", - "email": "john.doe@openwebui.com", - "password": "password", + 'name': 'John Doe', + 'email': 'john.doe@openwebui.com', + 'password': 'password', }, ) assert response.status_code == 200 data = response.json() - assert data["id"] is not None and len(data["id"]) > 0 - assert data["name"] == "John Doe" - assert data["email"] == "john.doe@openwebui.com" - assert data["role"] in ["admin", "user", "pending"] - assert data["profile_image_url"] == "/user.png" - assert data["token"] is not None and len(data["token"]) > 0 - assert data["token_type"] == "Bearer" + assert data['id'] is not None and len(data['id']) > 0 + assert data['name'] == 'John Doe' + assert data['email'] == 'john.doe@openwebui.com' + assert data['role'] in ['admin', 'user', 'pending'] + assert data['profile_image_url'] == '/user.png' + assert data['token'] is not None and len(data['token']) > 0 + assert data['token_type'] == 'Bearer' def test_add_user(self): with mock_webui_user(): response = self.fast_api_client.post( - self.create_url("/add"), + self.create_url('/add'), json={ - "name": "John Doe 2", - "email": "john.doe2@openwebui.com", - "password": "password2", - "role": "admin", + 'name': 'John Doe 2', + 'email': 'john.doe2@openwebui.com', + 'password': 'password2', + 'role': 'admin', }, ) assert response.status_code == 200 data = response.json() - assert data["id"] is not None and len(data["id"]) > 0 - assert data["name"] == "John Doe 2" - assert data["email"] == "john.doe2@openwebui.com" - assert data["role"] == "admin" - assert data["profile_image_url"] == "/user.png" - assert data["token"] is not None and len(data["token"]) > 0 - assert data["token_type"] == "Bearer" + assert data['id'] is not None and len(data['id']) > 0 + assert data['name'] == 'John Doe 2' + assert data['email'] == 'john.doe2@openwebui.com' + assert data['role'] == 'admin' + assert data['profile_image_url'] == '/user.png' + assert data['token'] is not None and len(data['token']) > 0 + assert data['token_type'] == 'Bearer' def test_get_admin_details(self): self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password="password", - name="John Doe", - profile_image_url="/user.png", - role="admin", + email='john.doe@openwebui.com', + password='password', + name='John Doe', + profile_image_url='/user.png', + role='admin', ) with mock_webui_user(): - response = self.fast_api_client.get(self.create_url("/admin/details")) + response = self.fast_api_client.get(self.create_url('/admin/details')) assert response.status_code == 200 assert response.json() == { - "name": "John Doe", - "email": "john.doe@openwebui.com", + 'name': 'John Doe', + 'email': 'john.doe@openwebui.com', } def test_create_api_key_(self): user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password="password", - name="John Doe", - profile_image_url="/user.png", - role="admin", + email='john.doe@openwebui.com', + password='password', + name='John Doe', + profile_image_url='/user.png', + role='admin', ) with mock_webui_user(id=user.id): - response = self.fast_api_client.post(self.create_url("/api_key")) + response = self.fast_api_client.post(self.create_url('/api_key')) assert response.status_code == 200 data = response.json() - assert data["api_key"] is not None - assert len(data["api_key"]) > 0 + assert data['api_key'] is not None + assert len(data['api_key']) > 0 def test_delete_api_key(self): user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password="password", - name="John Doe", - profile_image_url="/user.png", - role="admin", + email='john.doe@openwebui.com', + password='password', + name='John Doe', + profile_image_url='/user.png', + role='admin', ) - self.users.update_user_api_key_by_id(user.id, "abc") + self.users.update_user_api_key_by_id(user.id, 'abc') with mock_webui_user(id=user.id): - response = self.fast_api_client.delete(self.create_url("/api_key")) + response = self.fast_api_client.delete(self.create_url('/api_key')) assert response.status_code == 200 assert response.json() == True db_user = self.users.get_user_by_id(user.id) @@ -187,14 +183,14 @@ class TestAuths(AbstractPostgresTest): def test_get_api_key(self): user = self.auths.insert_new_auth( - email="john.doe@openwebui.com", - password="password", - name="John Doe", - profile_image_url="/user.png", - role="admin", + email='john.doe@openwebui.com', + password='password', + name='John Doe', + profile_image_url='/user.png', + role='admin', ) - self.users.update_user_api_key_by_id(user.id, "abc") + self.users.update_user_api_key_by_id(user.id, 'abc') with mock_webui_user(id=user.id): - response = self.fast_api_client.get(self.create_url("/api_key")) + response = self.fast_api_client.get(self.create_url('/api_key')) assert response.status_code == 200 - assert response.json() == {"api_key": "abc"} + assert response.json() == {'api_key': 'abc'} diff --git a/backend/open_webui/test/apps/webui/routers/test_models.py b/backend/open_webui/test/apps/webui/routers/test_models.py index c16ca9d073..6a7e26a519 100644 --- a/backend/open_webui/test/apps/webui/routers/test_models.py +++ b/backend/open_webui/test/apps/webui/routers/test_models.py @@ -3,7 +3,7 @@ from test.util.mock_user import mock_webui_user class TestModels(AbstractPostgresTest): - BASE_PATH = "/api/v1/models" + BASE_PATH = '/api/v1/models' def setup_class(cls): super().setup_class() @@ -12,50 +12,46 @@ class TestModels(AbstractPostgresTest): cls.models = Model def test_models(self): - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url('/')) assert response.status_code == 200 assert len(response.json()) == 0 - with mock_webui_user(id="2"): + with mock_webui_user(id='2'): response = self.fast_api_client.post( - self.create_url("/add"), + self.create_url('/add'), json={ - "id": "my-model", - "base_model_id": "base-model-id", - "name": "Hello World", - "meta": { - "profile_image_url": "/static/favicon.png", - "description": "description", - "capabilities": None, - "model_config": {}, + 'id': 'my-model', + 'base_model_id': 'base-model-id', + 'name': 'Hello World', + 'meta': { + 'profile_image_url': '/static/favicon.png', + 'description': 'description', + 'capabilities': None, + 'model_config': {}, }, - "params": {}, + 'params': {}, }, ) assert response.status_code == 200 - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url('/')) assert response.status_code == 200 assert len(response.json()) == 1 - with mock_webui_user(id="2"): - response = self.fast_api_client.get( - self.create_url(query_params={"id": "my-model"}) - ) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url(query_params={'id': 'my-model'})) assert response.status_code == 200 data = response.json()[0] - assert data["id"] == "my-model" - assert data["name"] == "Hello World" + assert data['id'] == 'my-model' + assert data['name'] == 'Hello World' - with mock_webui_user(id="2"): - response = self.fast_api_client.delete( - self.create_url("/delete?id=my-model") - ) + with mock_webui_user(id='2'): + response = self.fast_api_client.delete(self.create_url('/delete?id=my-model')) assert response.status_code == 200 - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url('/')) assert response.status_code == 200 assert len(response.json()) == 0 diff --git a/backend/open_webui/test/apps/webui/routers/test_users.py b/backend/open_webui/test/apps/webui/routers/test_users.py index 3108729710..ad64df3508 100644 --- a/backend/open_webui/test/apps/webui/routers/test_users.py +++ b/backend/open_webui/test/apps/webui/routers/test_users.py @@ -3,17 +3,17 @@ from test.util.mock_user import mock_webui_user def _get_user_by_id(data, param): - return next((item for item in data if item["id"] == param), None) + return next((item for item in data if item['id'] == param), None) def _assert_user(data, id, **kwargs): user = _get_user_by_id(data, id) assert user is not None comparison_data = { - "name": f"user {id}", - "email": f"user{id}@openwebui.com", - "profile_image_url": f"/api/v1/users/{id}/profile/image", - "role": "user", + 'name': f'user {id}', + 'email': f'user{id}@openwebui.com', + 'profile_image_url': f'/api/v1/users/{id}/profile/image', + 'role': 'user', **kwargs, } for key, value in comparison_data.items(): @@ -21,7 +21,7 @@ def _assert_user(data, id, **kwargs): class TestUsers(AbstractPostgresTest): - BASE_PATH = "/api/v1/users" + BASE_PATH = '/api/v1/users' def setup_class(cls): super().setup_class() @@ -32,136 +32,134 @@ class TestUsers(AbstractPostgresTest): def setup_method(self): super().setup_method() self.users.insert_new_user( - id="1", - name="user 1", - email="user1@openwebui.com", - profile_image_url="/user1.png", - role="user", + id='1', + name='user 1', + email='user1@openwebui.com', + profile_image_url='/user1.png', + role='user', ) self.users.insert_new_user( - id="2", - name="user 2", - email="user2@openwebui.com", - profile_image_url="/user2.png", - role="user", + id='2', + name='user 2', + email='user2@openwebui.com', + profile_image_url='/user2.png', + role='user', ) def test_users(self): # Get all users - with mock_webui_user(id="3"): - response = self.fast_api_client.get(self.create_url("")) + with mock_webui_user(id='3'): + response = self.fast_api_client.get(self.create_url('')) assert response.status_code == 200 assert len(response.json()) == 2 data = response.json() - _assert_user(data, "1") - _assert_user(data, "2") + _assert_user(data, '1') + _assert_user(data, '2') # update role - with mock_webui_user(id="3"): - response = self.fast_api_client.post( - self.create_url("/update/role"), json={"id": "2", "role": "admin"} - ) + with mock_webui_user(id='3'): + response = self.fast_api_client.post(self.create_url('/update/role'), json={'id': '2', 'role': 'admin'}) assert response.status_code == 200 - _assert_user([response.json()], "2", role="admin") + _assert_user([response.json()], '2', role='admin') # Get all users - with mock_webui_user(id="3"): - response = self.fast_api_client.get(self.create_url("")) + with mock_webui_user(id='3'): + response = self.fast_api_client.get(self.create_url('')) assert response.status_code == 200 assert len(response.json()) == 2 data = response.json() - _assert_user(data, "1") - _assert_user(data, "2", role="admin") + _assert_user(data, '1') + _assert_user(data, '2', role='admin') # Get (empty) user settings - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/user/settings")) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url('/user/settings')) assert response.status_code == 200 assert response.json() is None # Update user settings - with mock_webui_user(id="2"): + with mock_webui_user(id='2'): response = self.fast_api_client.post( - self.create_url("/user/settings/update"), + self.create_url('/user/settings/update'), json={ - "ui": {"attr1": "value1", "attr2": "value2"}, - "model_config": {"attr3": "value3", "attr4": "value4"}, + 'ui': {'attr1': 'value1', 'attr2': 'value2'}, + 'model_config': {'attr3': 'value3', 'attr4': 'value4'}, }, ) assert response.status_code == 200 # Get user settings - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/user/settings")) + with mock_webui_user(id='2'): + response = self.fast_api_client.get(self.create_url('/user/settings')) assert response.status_code == 200 assert response.json() == { - "ui": {"attr1": "value1", "attr2": "value2"}, - "model_config": {"attr3": "value3", "attr4": "value4"}, + 'ui': {'attr1': 'value1', 'attr2': 'value2'}, + 'model_config': {'attr3': 'value3', 'attr4': 'value4'}, } # Get (empty) user info - with mock_webui_user(id="1"): - response = self.fast_api_client.get(self.create_url("/user/info")) + with mock_webui_user(id='1'): + response = self.fast_api_client.get(self.create_url('/user/info')) assert response.status_code == 200 assert response.json() is None # Update user info - with mock_webui_user(id="1"): + with mock_webui_user(id='1'): response = self.fast_api_client.post( - self.create_url("/user/info/update"), - json={"attr1": "value1", "attr2": "value2"}, + self.create_url('/user/info/update'), + json={'attr1': 'value1', 'attr2': 'value2'}, ) assert response.status_code == 200 # Get user info - with mock_webui_user(id="1"): - response = self.fast_api_client.get(self.create_url("/user/info")) + with mock_webui_user(id='1'): + response = self.fast_api_client.get(self.create_url('/user/info')) assert response.status_code == 200 - assert response.json() == {"attr1": "value1", "attr2": "value2"} + assert response.json() == {'attr1': 'value1', 'attr2': 'value2'} # Get user by id - with mock_webui_user(id="1"): - response = self.fast_api_client.get(self.create_url("/2")) + with mock_webui_user(id='1'): + response = self.fast_api_client.get(self.create_url('/2')) assert response.status_code == 200 - assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"} + assert response.json() == {'name': 'user 2', 'profile_image_url': '/user2.png'} # Update user by id - with mock_webui_user(id="1"): + with mock_webui_user(id='1'): response = self.fast_api_client.post( - self.create_url("/2/update"), + self.create_url('/2/update'), json={ - "name": "user 2 updated", - "email": "user2-updated@openwebui.com", - "profile_image_url": "/user2-updated.png", + 'name': 'user 2 updated', + 'email': 'user2-updated@openwebui.com', + 'profile_image_url': '/user2-updated.png', }, ) assert response.status_code == 200 # Get all users - with mock_webui_user(id="3"): - response = self.fast_api_client.get(self.create_url("")) + with mock_webui_user(id='3'): + response = self.fast_api_client.get(self.create_url('')) assert response.status_code == 200 assert len(response.json()) == 2 data = response.json() - _assert_user(data, "1") + _assert_user(data, '1') _assert_user( data, - "2", - role="admin", - name="user 2 updated", - email="user2-updated@openwebui.com", - profile_image_url=f"/api/v1/users/2/profile/image", + '2', + role='admin', + name='user 2 updated', + email='user2-updated@openwebui.com', + profile_image_url=f'/api/v1/users/2/profile/image', ) # Delete user by id - with mock_webui_user(id="1"): - response = self.fast_api_client.delete(self.create_url("/2")) + with mock_webui_user(id='1'): + response = self.fast_api_client.delete(self.create_url('/2')) assert response.status_code == 200 # Get all users - with mock_webui_user(id="3"): - response = self.fast_api_client.get(self.create_url("")) + with mock_webui_user(id='3'): + response = self.fast_api_client.get(self.create_url('')) assert response.status_code == 200 assert len(response.json()) == 1 data = response.json() - _assert_user(data, "1") + _assert_user(data, '1') diff --git a/backend/open_webui/test/apps/webui/storage/test_provider.py b/backend/open_webui/test/apps/webui/storage/test_provider.py index 3c874592fe..806b072d87 100644 --- a/backend/open_webui/test/apps/webui/storage/test_provider.py +++ b/backend/open_webui/test/apps/webui/storage/test_provider.py @@ -13,9 +13,9 @@ from unittest.mock import MagicMock def mock_upload_dir(monkeypatch, tmp_path): """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory.""" - directory = tmp_path / "uploads" + directory = tmp_path / 'uploads' directory.mkdir() - monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory)) + monkeypatch.setattr(provider, 'UPLOAD_DIR', str(directory)) return directory @@ -29,16 +29,16 @@ def test_imports(): def test_get_storage_provider(): - Storage = provider.get_storage_provider("local") + Storage = provider.get_storage_provider('local') assert isinstance(Storage, provider.LocalStorageProvider) - Storage = provider.get_storage_provider("s3") + Storage = provider.get_storage_provider('s3') assert isinstance(Storage, provider.S3StorageProvider) - Storage = provider.get_storage_provider("gcs") + Storage = provider.get_storage_provider('gcs') assert isinstance(Storage, provider.GCSStorageProvider) - Storage = provider.get_storage_provider("azure") + Storage = provider.get_storage_provider('azure') assert isinstance(Storage, provider.AzureStorageProvider) with pytest.raises(RuntimeError): - provider.get_storage_provider("invalid") + provider.get_storage_provider('invalid') def test_class_instantiation(): @@ -58,10 +58,10 @@ def test_class_instantiation(): class TestLocalStorageProvider: Storage = provider.LocalStorageProvider() - file_content = b"test content" + file_content = b'test content' file_bytesio = io.BytesIO(file_content) - filename = "test.txt" - filename_extra = "test_exyta.txt" + filename = 'test.txt' + filename_extra = 'test_exyta.txt' file_bytesio_empty = io.BytesIO() def test_upload_file(self, monkeypatch, tmp_path): @@ -99,14 +99,13 @@ class TestLocalStorageProvider: @mock_aws class TestS3StorageProvider: - def __init__(self): self.Storage = provider.S3StorageProvider() - self.Storage.bucket_name = "my-bucket" - self.s3_client = boto3.resource("s3", region_name="us-east-1") - self.file_content = b"test content" - self.filename = "test.txt" - self.filename_extra = "test_exyta.txt" + self.Storage.bucket_name = 'my-bucket' + self.s3_client = boto3.resource('s3', region_name='us-east-1') + self.file_content = b'test content' + self.filename = 'test.txt' + self.filename_extra = 'test_exyta.txt' self.file_bytesio_empty = io.BytesIO() super().__init__() @@ -116,25 +115,21 @@ class TestS3StorageProvider: with pytest.raises(Exception): self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) - contents, s3_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) object = self.s3_client.Object(self.Storage.bucket_name, self.filename) - assert self.file_content == object.get()["Body"].read() + assert self.file_content == object.get()['Body'].read() # local checks assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content - assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename + assert s3_file_path == 's3://' + self.Storage.bucket_name + '/' + self.filename with pytest.raises(ValueError): self.Storage.upload_file(self.file_bytesio_empty, self.filename) def test_get_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) - contents, s3_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) file_path = self.Storage.get_file(s3_file_path) assert file_path == str(upload_dir / self.filename) assert (upload_dir / self.filename).exists() @@ -142,17 +137,15 @@ class TestS3StorageProvider: def test_delete_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) - contents, s3_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) assert (upload_dir / self.filename).exists() self.Storage.delete_file(s3_file_path) assert not (upload_dir / self.filename).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename).load() - error = exc.value.response["Error"] - assert error["Code"] == "404" - assert error["Message"] == "Not Found" + error = exc.value.response['Error'] + assert error['Code'] == '404' + assert error['Message'] == 'Not Found' def test_delete_all_files(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) @@ -160,12 +153,12 @@ class TestS3StorageProvider: self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) object = self.s3_client.Object(self.Storage.bucket_name, self.filename) - assert self.file_content == object.get()["Body"].read() + assert self.file_content == object.get()['Body'].read() assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra) - assert self.file_content == object.get()["Body"].read() + assert self.file_content == object.get()['Body'].read() assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content @@ -173,15 +166,15 @@ class TestS3StorageProvider: assert not (upload_dir / self.filename).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename).load() - error = exc.value.response["Error"] - assert error["Code"] == "404" - assert error["Message"] == "Not Found" + error = exc.value.response['Error'] + assert error['Code'] == '404' + assert error['Message'] == 'Not Found' assert not (upload_dir / self.filename_extra).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load() - error = exc.value.response["Error"] - assert error["Code"] == "404" - assert error["Message"] == "Not Found" + error = exc.value.response['Error'] + assert error['Code'] == '404' + assert error['Message'] == 'Not Found' self.Storage.delete_all_files() assert not (upload_dir / self.filename).exists() @@ -190,8 +183,8 @@ class TestS3StorageProvider: def test_init_without_credentials(self, monkeypatch): """Test that S3StorageProvider can initialize without explicit credentials.""" # Temporarily unset the environment variables - monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None) - monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None) + monkeypatch.setattr(provider, 'S3_ACCESS_KEY_ID', None) + monkeypatch.setattr(provider, 'S3_SECRET_ACCESS_KEY', None) # Should not raise an exception storage = provider.S3StorageProvider() @@ -201,19 +194,19 @@ class TestS3StorageProvider: class TestGCSStorageProvider: Storage = provider.GCSStorageProvider() - Storage.bucket_name = "my-bucket" - file_content = b"test content" - filename = "test.txt" - filename_extra = "test_exyta.txt" + Storage.bucket_name = 'my-bucket' + file_content = b'test content' + filename = 'test.txt' + filename_extra = 'test_exyta.txt' file_bytesio_empty = io.BytesIO() - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def setup(self): - host, port = "localhost", 9023 + host, port = 'localhost', 9023 server = create_server(host, port, in_memory=True) server.start() - os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}" + os.environ['STORAGE_EMULATOR_HOST'] = f'http://{host}:{port}' gcs_client = storage.Client() bucket = gcs_client.bucket(self.Storage.bucket_name) @@ -227,36 +220,30 @@ class TestGCSStorageProvider: upload_dir = mock_upload_dir(monkeypatch, tmp_path) # catch error if bucket does not exist with pytest.raises(Exception): - self.Storage.bucket = monkeypatch(self.Storage, "bucket", None) + self.Storage.bucket = monkeypatch(self.Storage, 'bucket', None) self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) - contents, gcs_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) object = self.Storage.bucket.get_blob(self.filename) assert self.file_content == object.download_as_bytes() # local checks assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content - assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename + assert gcs_file_path == 'gs://' + self.Storage.bucket_name + '/' + self.filename # test error if file is empty with pytest.raises(ValueError): self.Storage.upload_file(self.file_bytesio_empty, self.filename) def test_get_file(self, monkeypatch, tmp_path, setup): upload_dir = mock_upload_dir(monkeypatch, tmp_path) - contents, gcs_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) file_path = self.Storage.get_file(gcs_file_path) assert file_path == str(upload_dir / self.filename) assert (upload_dir / self.filename).exists() def test_delete_file(self, monkeypatch, tmp_path, setup): upload_dir = mock_upload_dir(monkeypatch, tmp_path) - contents, gcs_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) # ensure that local directory has the uploaded file as well assert (upload_dir / self.filename).exists() assert self.Storage.bucket.get_blob(self.filename).name == self.filename @@ -278,10 +265,7 @@ class TestGCSStorageProvider: object = self.Storage.bucket.get_blob(self.filename_extra) assert (upload_dir / self.filename_extra).exists() assert (upload_dir / self.filename_extra).read_bytes() == self.file_content - assert ( - self.Storage.bucket.get_blob(self.filename_extra).name - == self.filename_extra - ) + assert self.Storage.bucket.get_blob(self.filename_extra).name == self.filename_extra assert self.file_content == object.download_as_bytes() self.Storage.delete_all_files() @@ -295,7 +279,7 @@ class TestAzureStorageProvider: def __init__(self): super().__init__() - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def setup_storage(self, monkeypatch): # Create mock Blob Service Client and related clients mock_blob_service_client = MagicMock() @@ -303,32 +287,28 @@ class TestAzureStorageProvider: mock_blob_client = MagicMock() # Set up return values for the mock - mock_blob_service_client.get_container_client.return_value = ( - mock_container_client - ) + mock_blob_service_client.get_container_client.return_value = mock_container_client mock_container_client.get_blob_client.return_value = mock_blob_client # Monkeypatch the Azure classes to return our mocks monkeypatch.setattr( azure.storage.blob, - "BlobServiceClient", + 'BlobServiceClient', lambda *args, **kwargs: mock_blob_service_client, ) monkeypatch.setattr( azure.storage.blob, - "ContainerClient", + 'ContainerClient', lambda *args, **kwargs: mock_container_client, ) - monkeypatch.setattr( - azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client - ) + monkeypatch.setattr(azure.storage.blob, 'BlobClient', lambda *args, **kwargs: mock_blob_client) self.Storage = provider.AzureStorageProvider() - self.Storage.endpoint = "https://myaccount.blob.core.windows.net" - self.Storage.container_name = "my-container" - self.file_content = b"test content" - self.filename = "test.txt" - self.filename_extra = "test_extra.txt" + self.Storage.endpoint = 'https://myaccount.blob.core.windows.net' + self.Storage.container_name = 'my-container' + self.file_content = b'test content' + self.filename = 'test.txt' + self.filename_extra = 'test_extra.txt' self.file_bytesio_empty = io.BytesIO() # Apply mocks to the Storage instance @@ -339,18 +319,14 @@ class TestAzureStorageProvider: upload_dir = mock_upload_dir(monkeypatch, tmp_path) # Simulate an error when container does not exist - self.Storage.container_client.get_blob_client.side_effect = Exception( - "Container does not exist" - ) + self.Storage.container_client.get_blob_client.side_effect = Exception('Container does not exist') with pytest.raises(Exception): self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) # Reset side effect and create container self.Storage.container_client.get_blob_client.side_effect = None self.Storage.create_container() - contents, azure_file_path = self.Storage.upload_file( - io.BytesIO(self.file_content), self.filename - ) + contents, azure_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) # Assertions self.Storage.container_client.get_blob_client.assert_called_with(self.filename) @@ -359,8 +335,7 @@ class TestAzureStorageProvider: ) assert contents == self.file_content assert ( - azure_file_path - == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + azure_file_path == f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}' ) assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content @@ -375,11 +350,9 @@ class TestAzureStorageProvider: # Mock upload behavior self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) # Mock blob download behavior - self.Storage.container_client.get_blob_client().download_blob().readall.return_value = ( - self.file_content - ) + self.Storage.container_client.get_blob_client().download_blob().readall.return_value = self.file_content - file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}' file_path = self.Storage.get_file(file_url) assert file_path == str(upload_dir / self.filename) @@ -395,7 +368,7 @@ class TestAzureStorageProvider: # Mock deletion self.Storage.container_client.get_blob_client().delete_blob.return_value = None - file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}' self.Storage.delete_file(file_url) self.Storage.container_client.get_blob_client().delete_blob.assert_called_once() @@ -411,8 +384,8 @@ class TestAzureStorageProvider: # Mock listing and deletion behavior self.Storage.container_client.list_blobs.return_value = [ - {"name": self.filename}, - {"name": self.filename_extra}, + {'name': self.filename}, + {'name': self.filename_extra}, ] self.Storage.container_client.get_blob_client().delete_blob.return_value = None @@ -426,10 +399,8 @@ class TestAzureStorageProvider: def test_get_file_not_found(self, monkeypatch): self.Storage.create_container() - file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}' # Mock behavior to raise an error for missing blobs - self.Storage.container_client.get_blob_client().download_blob.side_effect = ( - Exception("Blob not found") - ) - with pytest.raises(Exception, match="Blob not found"): + self.Storage.container_client.get_blob_client().download_blob.side_effect = Exception('Blob not found') + with pytest.raises(Exception, match='Blob not found'): self.Storage.get_file(file_url) diff --git a/backend/open_webui/test/util/test_redis.py b/backend/open_webui/test/util/test_redis.py index 8c393ce9d9..036fb36362 100644 --- a/backend/open_webui/test/util/test_redis.py +++ b/backend/open_webui/test/util/test_redis.py @@ -16,84 +16,84 @@ class TestSentinelRedisProxy: def test_parse_redis_service_url_valid(self): """Test parsing valid Redis service URL""" - url = "redis://user:pass@mymaster:6379/0" + url = 'redis://user:pass@mymaster:6379/0' result = parse_redis_service_url(url) - assert result["username"] == "user" - assert result["password"] == "pass" - assert result["service"] == "mymaster" - assert result["port"] == 6379 - assert result["db"] == 0 + assert result['username'] == 'user' + assert result['password'] == 'pass' + assert result['service'] == 'mymaster' + assert result['port'] == 6379 + assert result['db'] == 0 def test_parse_redis_service_url_defaults(self): """Test parsing Redis service URL with defaults""" - url = "redis://mymaster" + url = 'redis://mymaster' result = parse_redis_service_url(url) - assert result["username"] is None - assert result["password"] is None - assert result["service"] == "mymaster" - assert result["port"] == 6379 - assert result["db"] == 0 + assert result['username'] is None + assert result['password'] is None + assert result['service'] == 'mymaster' + assert result['port'] == 6379 + assert result['db'] == 0 def test_parse_redis_service_url_invalid_scheme(self): """Test parsing invalid URL scheme""" - with pytest.raises(ValueError, match="Invalid Redis URL scheme"): - parse_redis_service_url("http://invalid") + with pytest.raises(ValueError, match='Invalid Redis URL scheme'): + parse_redis_service_url('http://invalid') def test_get_sentinels_from_env(self): """Test parsing sentinel hosts from environment""" - hosts = "sentinel1,sentinel2,sentinel3" - port = "26379" + hosts = 'sentinel1,sentinel2,sentinel3' + port = '26379' result = get_sentinels_from_env(hosts, port) - expected = [("sentinel1", 26379), ("sentinel2", 26379), ("sentinel3", 26379)] + expected = [('sentinel1', 26379), ('sentinel2', 26379), ('sentinel3', 26379)] assert result == expected def test_get_sentinels_from_env_empty(self): """Test empty sentinel hosts""" - result = get_sentinels_from_env(None, "26379") + result = get_sentinels_from_env(None, '26379') assert result == [] - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_sentinel_redis_proxy_sync_success(self, mock_sentinel_class): """Test successful sync operation with SentinelRedisProxy""" mock_sentinel = Mock() mock_master = Mock() - mock_master.get.return_value = "test_value" + mock_master.get.return_value = 'test_value' mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test attribute access - get_method = proxy.__getattr__("get") - result = get_method("test_key") + get_method = proxy.__getattr__('get') + result = get_method('test_key') - assert result == "test_value" - mock_sentinel.master_for.assert_called_with("mymaster") - mock_master.get.assert_called_with("test_key") + assert result == 'test_value' + mock_sentinel.master_for.assert_called_with('mymaster') + mock_master.get.assert_called_with('test_key') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_sentinel_redis_proxy_async_success(self, mock_sentinel_class): """Test successful async operation with SentinelRedisProxy""" mock_sentinel = Mock() mock_master = Mock() - mock_master.get = AsyncMock(return_value="test_value") + mock_master.get = AsyncMock(return_value='test_value') mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test async attribute access - get_method = proxy.__getattr__("get") - result = await get_method("test_key") + get_method = proxy.__getattr__('get') + result = await get_method('test_key') - assert result == "test_value" - mock_sentinel.master_for.assert_called_with("mymaster") - mock_master.get.assert_called_with("test_key") + assert result == 'test_value' + mock_sentinel.master_for.assert_called_with('mymaster') + mock_master.get.assert_called_with('test_key') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_sentinel_redis_proxy_failover_retry(self, mock_sentinel_class): """Test retry mechanism during failover""" mock_sentinel = Mock() @@ -101,39 +101,39 @@ class TestSentinelRedisProxy: # First call fails, second succeeds mock_master.get.side_effect = [ - redis.exceptions.ConnectionError("Master down"), - "test_value", + redis.exceptions.ConnectionError('Master down'), + 'test_value', ] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) - get_method = proxy.__getattr__("get") - result = get_method("test_key") + get_method = proxy.__getattr__('get') + result = get_method('test_key') - assert result == "test_value" + assert result == 'test_value' assert mock_master.get.call_count == 2 - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_sentinel_redis_proxy_max_retries_exceeded(self, mock_sentinel_class): """Test failure after max retries exceeded""" mock_sentinel = Mock() mock_master = Mock() # All calls fail - mock_master.get.side_effect = redis.exceptions.ConnectionError("Master down") + mock_master.get.side_effect = redis.exceptions.ConnectionError('Master down') mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) - get_method = proxy.__getattr__("get") + get_method = proxy.__getattr__('get') with pytest.raises(redis.exceptions.ConnectionError): - get_method("test_key") + get_method('test_key') assert mock_master.get.call_count == MAX_RETRY_COUNT - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_sentinel_redis_proxy_readonly_error_retry(self, mock_sentinel_class): """Test retry on ReadOnlyError""" mock_sentinel = Mock() @@ -141,20 +141,20 @@ class TestSentinelRedisProxy: # First call gets ReadOnlyError (old master), second succeeds (new master) mock_master.get.side_effect = [ - redis.exceptions.ReadOnlyError("Read only"), - "test_value", + redis.exceptions.ReadOnlyError('Read only'), + 'test_value', ] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) - get_method = proxy.__getattr__("get") - result = get_method("test_key") + get_method = proxy.__getattr__('get') + result = get_method('test_key') - assert result == "test_value" + assert result == 'test_value' assert mock_master.get.call_count == 2 - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_sentinel_redis_proxy_factory_methods(self, mock_sentinel_class): """Test factory methods are passed through directly""" mock_sentinel = Mock() @@ -163,61 +163,53 @@ class TestSentinelRedisProxy: mock_master.pipeline.return_value = mock_pipeline mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Factory methods should be passed through without wrapping - pipeline_method = proxy.__getattr__("pipeline") + pipeline_method = proxy.__getattr__('pipeline') result = pipeline_method() assert result == mock_pipeline mock_master.pipeline.assert_called_once() - @patch("redis.sentinel.Sentinel") - @patch("redis.from_url") - def test_get_redis_connection_with_sentinel( - self, mock_from_url, mock_sentinel_class - ): + @patch('redis.sentinel.Sentinel') + @patch('redis.from_url') + def test_get_redis_connection_with_sentinel(self, mock_from_url, mock_sentinel_class): """Test getting Redis connection with Sentinel""" mock_sentinel = Mock() mock_sentinel_class.return_value = mock_sentinel - sentinels = [("sentinel1", 26379), ("sentinel2", 26379)] - redis_url = "redis://user:pass@mymaster:6379/0" + sentinels = [('sentinel1', 26379), ('sentinel2', 26379)] + redis_url = 'redis://user:pass@mymaster:6379/0' - result = get_redis_connection( - redis_url=redis_url, redis_sentinels=sentinels, async_mode=False - ) + result = get_redis_connection(redis_url=redis_url, redis_sentinels=sentinels, async_mode=False) assert isinstance(result, SentinelRedisProxy) mock_sentinel_class.assert_called_once() mock_from_url.assert_not_called() - @patch("redis.Redis.from_url") + @patch('redis.Redis.from_url') def test_get_redis_connection_without_sentinel(self, mock_from_url): """Test getting Redis connection without Sentinel""" mock_redis = Mock() mock_from_url.return_value = mock_redis - redis_url = "redis://localhost:6379/0" + redis_url = 'redis://localhost:6379/0' - result = get_redis_connection( - redis_url=redis_url, redis_sentinels=None, async_mode=False - ) + result = get_redis_connection(redis_url=redis_url, redis_sentinels=None, async_mode=False) assert result == mock_redis mock_from_url.assert_called_once_with(redis_url, decode_responses=True) - @patch("redis.asyncio.from_url") + @patch('redis.asyncio.from_url') def test_get_redis_connection_without_sentinel_async(self, mock_from_url): """Test getting async Redis connection without Sentinel""" mock_redis = Mock() mock_from_url.return_value = mock_redis - redis_url = "redis://localhost:6379/0" + redis_url = 'redis://localhost:6379/0' - result = get_redis_connection( - redis_url=redis_url, redis_sentinels=None, async_mode=True - ) + result = get_redis_connection(redis_url=redis_url, redis_sentinels=None, async_mode=True) assert result == mock_redis mock_from_url.assert_called_once_with(redis_url, decode_responses=True) @@ -226,7 +218,7 @@ class TestSentinelRedisProxy: class TestSentinelRedisProxyCommands: """Test Redis commands through SentinelRedisProxy""" - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_hash_commands_sync(self, mock_sentinel_class): """Test Redis hash commands in sync mode""" mock_sentinel = Mock() @@ -234,39 +226,39 @@ class TestSentinelRedisProxyCommands: # Mock hash command responses mock_master.hset.return_value = 1 - mock_master.hget.return_value = "test_value" - mock_master.hgetall.return_value = {"key1": "value1", "key2": "value2"} + mock_master.hget.return_value = 'test_value' + mock_master.hgetall.return_value = {'key1': 'value1', 'key2': 'value2'} mock_master.hdel.return_value = 1 mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test hset - hset_method = proxy.__getattr__("hset") - result = hset_method("test_hash", "field1", "value1") + hset_method = proxy.__getattr__('hset') + result = hset_method('test_hash', 'field1', 'value1') assert result == 1 - mock_master.hset.assert_called_with("test_hash", "field1", "value1") + mock_master.hset.assert_called_with('test_hash', 'field1', 'value1') # Test hget - hget_method = proxy.__getattr__("hget") - result = hget_method("test_hash", "field1") - assert result == "test_value" - mock_master.hget.assert_called_with("test_hash", "field1") + hget_method = proxy.__getattr__('hget') + result = hget_method('test_hash', 'field1') + assert result == 'test_value' + mock_master.hget.assert_called_with('test_hash', 'field1') # Test hgetall - hgetall_method = proxy.__getattr__("hgetall") - result = hgetall_method("test_hash") - assert result == {"key1": "value1", "key2": "value2"} - mock_master.hgetall.assert_called_with("test_hash") + hgetall_method = proxy.__getattr__('hgetall') + result = hgetall_method('test_hash') + assert result == {'key1': 'value1', 'key2': 'value2'} + mock_master.hgetall.assert_called_with('test_hash') # Test hdel - hdel_method = proxy.__getattr__("hdel") - result = hdel_method("test_hash", "field1") + hdel_method = proxy.__getattr__('hdel') + result = hdel_method('test_hash', 'field1') assert result == 1 - mock_master.hdel.assert_called_with("test_hash", "field1") + mock_master.hdel.assert_called_with('test_hash', 'field1') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_hash_commands_async(self, mock_sentinel_class): """Test Redis hash commands in async mode""" @@ -275,34 +267,32 @@ class TestSentinelRedisProxyCommands: # Mock async hash command responses mock_master.hset = AsyncMock(return_value=1) - mock_master.hget = AsyncMock(return_value="test_value") - mock_master.hgetall = AsyncMock( - return_value={"key1": "value1", "key2": "value2"} - ) + mock_master.hget = AsyncMock(return_value='test_value') + mock_master.hgetall = AsyncMock(return_value={'key1': 'value1', 'key2': 'value2'}) mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test hset - hset_method = proxy.__getattr__("hset") - result = await hset_method("test_hash", "field1", "value1") + hset_method = proxy.__getattr__('hset') + result = await hset_method('test_hash', 'field1', 'value1') assert result == 1 - mock_master.hset.assert_called_with("test_hash", "field1", "value1") + mock_master.hset.assert_called_with('test_hash', 'field1', 'value1') # Test hget - hget_method = proxy.__getattr__("hget") - result = await hget_method("test_hash", "field1") - assert result == "test_value" - mock_master.hget.assert_called_with("test_hash", "field1") + hget_method = proxy.__getattr__('hget') + result = await hget_method('test_hash', 'field1') + assert result == 'test_value' + mock_master.hget.assert_called_with('test_hash', 'field1') # Test hgetall - hgetall_method = proxy.__getattr__("hgetall") - result = await hgetall_method("test_hash") - assert result == {"key1": "value1", "key2": "value2"} - mock_master.hgetall.assert_called_with("test_hash") + hgetall_method = proxy.__getattr__('hgetall') + result = await hgetall_method('test_hash') + assert result == {'key1': 'value1', 'key2': 'value2'} + mock_master.hgetall.assert_called_with('test_hash') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_string_commands_sync(self, mock_sentinel_class): """Test Redis string commands in sync mode""" mock_sentinel = Mock() @@ -310,39 +300,39 @@ class TestSentinelRedisProxyCommands: # Mock string command responses mock_master.set.return_value = True - mock_master.get.return_value = "test_value" + mock_master.get.return_value = 'test_value' mock_master.delete.return_value = 1 mock_master.exists.return_value = True mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test set - set_method = proxy.__getattr__("set") - result = set_method("test_key", "test_value") + set_method = proxy.__getattr__('set') + result = set_method('test_key', 'test_value') assert result is True - mock_master.set.assert_called_with("test_key", "test_value") + mock_master.set.assert_called_with('test_key', 'test_value') # Test get - get_method = proxy.__getattr__("get") - result = get_method("test_key") - assert result == "test_value" - mock_master.get.assert_called_with("test_key") + get_method = proxy.__getattr__('get') + result = get_method('test_key') + assert result == 'test_value' + mock_master.get.assert_called_with('test_key') # Test delete - delete_method = proxy.__getattr__("delete") - result = delete_method("test_key") + delete_method = proxy.__getattr__('delete') + result = delete_method('test_key') assert result == 1 - mock_master.delete.assert_called_with("test_key") + mock_master.delete.assert_called_with('test_key') # Test exists - exists_method = proxy.__getattr__("exists") - result = exists_method("test_key") + exists_method = proxy.__getattr__('exists') + result = exists_method('test_key') assert result is True - mock_master.exists.assert_called_with("test_key") + mock_master.exists.assert_called_with('test_key') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_list_commands_sync(self, mock_sentinel_class): """Test Redis list commands in sync mode""" mock_sentinel = Mock() @@ -350,39 +340,39 @@ class TestSentinelRedisProxyCommands: # Mock list command responses mock_master.lpush.return_value = 1 - mock_master.rpop.return_value = "test_value" + mock_master.rpop.return_value = 'test_value' mock_master.llen.return_value = 5 - mock_master.lrange.return_value = ["item1", "item2", "item3"] + mock_master.lrange.return_value = ['item1', 'item2', 'item3'] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test lpush - lpush_method = proxy.__getattr__("lpush") - result = lpush_method("test_list", "item1") + lpush_method = proxy.__getattr__('lpush') + result = lpush_method('test_list', 'item1') assert result == 1 - mock_master.lpush.assert_called_with("test_list", "item1") + mock_master.lpush.assert_called_with('test_list', 'item1') # Test rpop - rpop_method = proxy.__getattr__("rpop") - result = rpop_method("test_list") - assert result == "test_value" - mock_master.rpop.assert_called_with("test_list") + rpop_method = proxy.__getattr__('rpop') + result = rpop_method('test_list') + assert result == 'test_value' + mock_master.rpop.assert_called_with('test_list') # Test llen - llen_method = proxy.__getattr__("llen") - result = llen_method("test_list") + llen_method = proxy.__getattr__('llen') + result = llen_method('test_list') assert result == 5 - mock_master.llen.assert_called_with("test_list") + mock_master.llen.assert_called_with('test_list') # Test lrange - lrange_method = proxy.__getattr__("lrange") - result = lrange_method("test_list", 0, -1) - assert result == ["item1", "item2", "item3"] - mock_master.lrange.assert_called_with("test_list", 0, -1) + lrange_method = proxy.__getattr__('lrange') + result = lrange_method('test_list', 0, -1) + assert result == ['item1', 'item2', 'item3'] + mock_master.lrange.assert_called_with('test_list', 0, -1) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_pubsub_commands_sync(self, mock_sentinel_class): """Test Redis pubsub commands in sync mode""" mock_sentinel = Mock() @@ -393,25 +383,25 @@ class TestSentinelRedisProxyCommands: mock_master.pubsub.return_value = mock_pubsub mock_master.publish.return_value = 1 mock_pubsub.subscribe.return_value = None - mock_pubsub.get_message.return_value = {"type": "message", "data": "test_data"} + mock_pubsub.get_message.return_value = {'type': 'message', 'data': 'test_data'} mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test pubsub (factory method - should pass through) - pubsub_method = proxy.__getattr__("pubsub") + pubsub_method = proxy.__getattr__('pubsub') result = pubsub_method() assert result == mock_pubsub mock_master.pubsub.assert_called_once() # Test publish - publish_method = proxy.__getattr__("publish") - result = publish_method("test_channel", "test_message") + publish_method = proxy.__getattr__('publish') + result = publish_method('test_channel', 'test_message') assert result == 1 - mock_master.publish.assert_called_with("test_channel", "test_message") + mock_master.publish.assert_called_with('test_channel', 'test_message') - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_pipeline_commands_sync(self, mock_sentinel_class): """Test Redis pipeline commands in sync mode""" mock_sentinel = Mock() @@ -422,19 +412,19 @@ class TestSentinelRedisProxyCommands: mock_master.pipeline.return_value = mock_pipeline mock_pipeline.set.return_value = mock_pipeline mock_pipeline.get.return_value = mock_pipeline - mock_pipeline.execute.return_value = [True, "test_value"] + mock_pipeline.execute.return_value = [True, 'test_value'] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test pipeline (factory method - should pass through) - pipeline_method = proxy.__getattr__("pipeline") + pipeline_method = proxy.__getattr__('pipeline') result = pipeline_method() assert result == mock_pipeline mock_master.pipeline.assert_called_once() - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_commands_with_failover_retry(self, mock_sentinel_class): """Test Redis commands with failover retry mechanism""" mock_sentinel = Mock() @@ -442,27 +432,27 @@ class TestSentinelRedisProxyCommands: # First call fails with connection error, second succeeds mock_master.hget.side_effect = [ - redis.exceptions.ConnectionError("Connection failed"), - "recovered_value", + redis.exceptions.ConnectionError('Connection failed'), + 'recovered_value', ] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test hget with retry - hget_method = proxy.__getattr__("hget") - result = hget_method("test_hash", "field1") + hget_method = proxy.__getattr__('hget') + result = hget_method('test_hash', 'field1') - assert result == "recovered_value" + assert result == 'recovered_value' assert mock_master.hget.call_count == 2 # Verify both calls were made with same parameters - expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)] + expected_calls = [(('test_hash', 'field1'),), (('test_hash', 'field1'),)] actual_calls = [call.args for call in mock_master.hget.call_args_list] assert actual_calls == expected_calls - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') def test_commands_with_readonly_error_retry(self, mock_sentinel_class): """Test Redis commands with ReadOnlyError retry mechanism""" mock_sentinel = Mock() @@ -470,32 +460,30 @@ class TestSentinelRedisProxyCommands: # First call fails with ReadOnlyError, second succeeds mock_master.hset.side_effect = [ - redis.exceptions.ReadOnlyError( - "READONLY You can't write against a read only replica" - ), + redis.exceptions.ReadOnlyError("READONLY You can't write against a read only replica"), 1, ] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=False) # Test hset with retry - hset_method = proxy.__getattr__("hset") - result = hset_method("test_hash", "field1", "value1") + hset_method = proxy.__getattr__('hset') + result = hset_method('test_hash', 'field1', 'value1') assert result == 1 assert mock_master.hset.call_count == 2 # Verify both calls were made with same parameters expected_calls = [ - (("test_hash", "field1", "value1"),), - (("test_hash", "field1", "value1"),), + (('test_hash', 'field1', 'value1'),), + (('test_hash', 'field1', 'value1'),), ] actual_calls = [call.args for call in mock_master.hset.call_args_list] assert actual_calls == expected_calls - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_async_commands_with_failover_retry(self, mock_sentinel_class): """Test async Redis commands with failover retry mechanism""" @@ -505,24 +493,24 @@ class TestSentinelRedisProxyCommands: # First call fails with connection error, second succeeds mock_master.hget = AsyncMock( side_effect=[ - redis.exceptions.ConnectionError("Connection failed"), - "recovered_value", + redis.exceptions.ConnectionError('Connection failed'), + 'recovered_value', ] ) mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test async hget with retry - hget_method = proxy.__getattr__("hget") - result = await hget_method("test_hash", "field1") + hget_method = proxy.__getattr__('hget') + result = await hget_method('test_hash', 'field1') - assert result == "recovered_value" + assert result == 'recovered_value' assert mock_master.hget.call_count == 2 # Verify both calls were made with same parameters - expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)] + expected_calls = [(('test_hash', 'field1'),), (('test_hash', 'field1'),)] actual_calls = [call.args for call in mock_master.hget.call_args_list] assert actual_calls == expected_calls @@ -530,7 +518,7 @@ class TestSentinelRedisProxyCommands: class TestSentinelRedisProxyFactoryMethods: """Test Redis factory methods in async mode - these are special cases that remain sync""" - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_pubsub_factory_method_async(self, mock_sentinel_class): """Test pubsub factory method in async mode - should pass through without wrapping""" @@ -542,10 +530,10 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.pubsub.return_value = mock_pubsub mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test pubsub factory method - should NOT be wrapped as async - pubsub_method = proxy.__getattr__("pubsub") + pubsub_method = proxy.__getattr__('pubsub') result = pubsub_method() assert result == mock_pubsub @@ -554,7 +542,7 @@ class TestSentinelRedisProxyFactoryMethods: # Verify it's not wrapped as async (no await needed) assert not inspect.iscoroutine(result) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_pipeline_factory_method_async(self, mock_sentinel_class): """Test pipeline factory method in async mode - should pass through without wrapping""" @@ -566,14 +554,14 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.pipeline.return_value = mock_pipeline mock_pipeline.set.return_value = mock_pipeline mock_pipeline.get.return_value = mock_pipeline - mock_pipeline.execute.return_value = [True, "test_value"] + mock_pipeline.execute.return_value = [True, 'test_value'] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test pipeline factory method - should NOT be wrapped as async - pipeline_method = proxy.__getattr__("pipeline") + pipeline_method = proxy.__getattr__('pipeline') result = pipeline_method() assert result == mock_pipeline @@ -583,10 +571,10 @@ class TestSentinelRedisProxyFactoryMethods: assert not inspect.iscoroutine(result) # Test pipeline usage (these should also be sync) - pipeline_result = result.set("key", "value").get("key").execute() - assert pipeline_result == [True, "test_value"] + pipeline_result = result.set('key', 'value').get('key').execute() + assert pipeline_result == [True, 'test_value'] - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_factory_methods_vs_regular_commands_async(self, mock_sentinel_class): """Test that factory methods behave differently from regular commands in async mode""" @@ -596,19 +584,19 @@ class TestSentinelRedisProxyFactoryMethods: # Mock both factory method and regular command mock_pubsub = Mock() mock_master.pubsub.return_value = mock_pubsub - mock_master.get = AsyncMock(return_value="test_value") + mock_master.get = AsyncMock(return_value='test_value') mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test factory method - should NOT be wrapped - pubsub_method = proxy.__getattr__("pubsub") + pubsub_method = proxy.__getattr__('pubsub') pubsub_result = pubsub_method() # Test regular command - should be wrapped as async - get_method = proxy.__getattr__("get") - get_result = get_method("test_key") + get_method = proxy.__getattr__('get') + get_result = get_method('test_key') # Factory method returns directly assert pubsub_result == mock_pubsub @@ -619,9 +607,9 @@ class TestSentinelRedisProxyFactoryMethods: # Regular command needs await actual_value = await get_result - assert actual_value == "test_value" + assert actual_value == 'test_value' - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_factory_methods_with_failover_async(self, mock_sentinel_class): """Test factory methods with failover in async mode""" @@ -631,16 +619,16 @@ class TestSentinelRedisProxyFactoryMethods: # First call fails, second succeeds mock_pubsub = Mock() mock_master.pubsub.side_effect = [ - redis.exceptions.ConnectionError("Connection failed"), + redis.exceptions.ConnectionError('Connection failed'), mock_pubsub, ] mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test pubsub factory method with failover - pubsub_method = proxy.__getattr__("pubsub") + pubsub_method = proxy.__getattr__('pubsub') result = pubsub_method() assert result == mock_pubsub @@ -649,7 +637,7 @@ class TestSentinelRedisProxyFactoryMethods: # Verify it's still not wrapped as async after retry assert not inspect.iscoroutine(result) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_monitor_factory_method_async(self, mock_sentinel_class): """Test monitor factory method in async mode - should pass through without wrapping""" @@ -661,10 +649,10 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.monitor.return_value = mock_monitor mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test monitor factory method - should NOT be wrapped as async - monitor_method = proxy.__getattr__("monitor") + monitor_method = proxy.__getattr__('monitor') result = monitor_method() assert result == mock_monitor @@ -673,7 +661,7 @@ class TestSentinelRedisProxyFactoryMethods: # Verify it's not wrapped as async (no await needed) assert not inspect.iscoroutine(result) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_client_factory_method_async(self, mock_sentinel_class): """Test client factory method in async mode - should pass through without wrapping""" @@ -685,10 +673,10 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.client.return_value = mock_client mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test client factory method - should NOT be wrapped as async - client_method = proxy.__getattr__("client") + client_method = proxy.__getattr__('client') result = client_method() assert result == mock_client @@ -697,7 +685,7 @@ class TestSentinelRedisProxyFactoryMethods: # Verify it's not wrapped as async (no await needed) assert not inspect.iscoroutine(result) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_transaction_factory_method_async(self, mock_sentinel_class): """Test transaction factory method in async mode - should pass through without wrapping""" @@ -709,10 +697,10 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.transaction.return_value = mock_transaction mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test transaction factory method - should NOT be wrapped as async - transaction_method = proxy.__getattr__("transaction") + transaction_method = proxy.__getattr__('transaction') result = transaction_method() assert result == mock_transaction @@ -721,7 +709,7 @@ class TestSentinelRedisProxyFactoryMethods: # Verify it's not wrapped as async (no await needed) assert not inspect.iscoroutine(result) - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_all_factory_methods_async(self, mock_sentinel_class): """Test all factory methods in async mode - comprehensive test""" @@ -730,11 +718,11 @@ class TestSentinelRedisProxyFactoryMethods: # Mock all factory methods mock_objects = { - "pipeline": Mock(), - "pubsub": Mock(), - "monitor": Mock(), - "client": Mock(), - "transaction": Mock(), + 'pipeline': Mock(), + 'pubsub': Mock(), + 'monitor': Mock(), + 'client': Mock(), + 'transaction': Mock(), } for method_name, mock_obj in mock_objects.items(): @@ -742,7 +730,7 @@ class TestSentinelRedisProxyFactoryMethods: mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Test all factory methods for method_name, expected_obj in mock_objects.items(): @@ -756,7 +744,7 @@ class TestSentinelRedisProxyFactoryMethods: # Reset mock for next iteration getattr(mock_master, method_name).reset_mock() - @patch("redis.sentinel.Sentinel") + @patch('redis.sentinel.Sentinel') @pytest.mark.asyncio async def test_mixed_factory_and_regular_commands_async(self, mock_sentinel_class): """Test using both factory methods and regular commands in async mode""" @@ -768,26 +756,26 @@ class TestSentinelRedisProxyFactoryMethods: mock_master.pipeline.return_value = mock_pipeline mock_pipeline.set.return_value = mock_pipeline mock_pipeline.get.return_value = mock_pipeline - mock_pipeline.execute.return_value = [True, "pipeline_value"] + mock_pipeline.execute.return_value = [True, 'pipeline_value'] - mock_master.get = AsyncMock(return_value="regular_value") + mock_master.get = AsyncMock(return_value='regular_value') mock_sentinel.master_for.return_value = mock_master - proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True) + proxy = SentinelRedisProxy(mock_sentinel, 'mymaster', async_mode=True) # Use factory method (sync) - pipeline = proxy.__getattr__("pipeline")() - pipeline_result = pipeline.set("key1", "value1").get("key1").execute() + pipeline = proxy.__getattr__('pipeline')() + pipeline_result = pipeline.set('key1', 'value1').get('key1').execute() # Use regular command (async) - get_method = proxy.__getattr__("get") - regular_result = await get_method("key2") + get_method = proxy.__getattr__('get') + regular_result = await get_method('key2') # Verify both work correctly - assert pipeline_result == [True, "pipeline_value"] - assert regular_result == "regular_value" + assert pipeline_result == [True, 'pipeline_value'] + assert regular_result == 'regular_value' # Verify calls mock_master.pipeline.assert_called_once() - mock_master.get.assert_called_with("key2") + mock_master.get.assert_called_with('key2') diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index e556d8a355..73589810d6 100644 --- a/backend/open_webui/tools/builtin.py +++ b/backend/open_webui/tools/builtin.py @@ -64,14 +64,14 @@ async def get_current_timestamp( now = datetime.datetime.now(datetime.timezone.utc) return json.dumps( { - "current_timestamp": int(now.timestamp()), - "current_iso": now.isoformat(), + 'current_timestamp': int(now.timestamp()), + 'current_iso': now.isoformat(), }, ensure_ascii=False, ) except Exception as e: - log.exception(f"get_current_timestamp error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'get_current_timestamp error: {e}') + return json.dumps({'error': str(e)}) async def calculate_timestamp( @@ -112,10 +112,10 @@ async def calculate_timestamp( return json.dumps( { - "current_timestamp": current_ts, - "current_iso": now.isoformat(), - "calculated_timestamp": adjusted_ts, - "calculated_iso": adjusted.isoformat(), + 'current_timestamp': current_ts, + 'current_iso': now.isoformat(), + 'calculated_timestamp': adjusted_ts, + 'calculated_iso': adjusted.isoformat(), }, ensure_ascii=False, ) @@ -130,16 +130,16 @@ async def calculate_timestamp( adjusted_ts = int(adjusted.timestamp()) return json.dumps( { - "current_timestamp": current_ts, - "current_iso": now.isoformat(), - "calculated_timestamp": adjusted_ts, - "calculated_iso": adjusted.isoformat(), + 'current_timestamp': current_ts, + 'current_iso': now.isoformat(), + 'calculated_timestamp': adjusted_ts, + 'calculated_iso': adjusted.isoformat(), }, ensure_ascii=False, ) except Exception as e: - log.exception(f"calculate_timestamp error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'calculate_timestamp error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -162,14 +162,18 @@ async def search_web( :return: JSON with search results containing title, link, and snippet for each result """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: engine = __request__.app.state.config.WEB_SEARCH_ENGINE user = UserModel(**__user__) if __user__ else None # Enforce maximum result count from config to prevent abuse - count = count if count < __request__.app.state.config.WEB_SEARCH_RESULT_COUNT else __request__.app.state.config.WEB_SEARCH_RESULT_COUNT + count = ( + count + if count < __request__.app.state.config.WEB_SEARCH_RESULT_COUNT + else __request__.app.state.config.WEB_SEARCH_RESULT_COUNT + ) results = await asyncio.to_thread(_search_web, __request__, engine, query, user) @@ -177,12 +181,12 @@ async def search_web( results = results[:count] if results else [] return json.dumps( - [{"title": r.title, "link": r.link, "snippet": r.snippet} for r in results], + [{'title': r.title, 'link': r.link, 'snippet': r.snippet} for r in results], ensure_ascii=False, ) except Exception as e: - log.exception(f"search_web error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_web error: {e}') + return json.dumps({'error': str(e)}) async def fetch_url( @@ -197,20 +201,20 @@ async def fetch_url( :return: The extracted text content from the page """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: content, _ = await asyncio.to_thread(get_content_from_url, __request__, url) # Truncate if configured (WEB_FETCH_MAX_CONTENT_LENGTH) - max_length = getattr(__request__.app.state.config, "WEB_FETCH_MAX_CONTENT_LENGTH", None) + max_length = getattr(__request__.app.state.config, 'WEB_FETCH_MAX_CONTENT_LENGTH', None) if max_length and max_length > 0 and len(content) > max_length: - content = content[:max_length] + "\n\n[Content truncated...]" + content = content[:max_length] + '\n\n[Content truncated...]' return content except Exception as e: - log.exception(f"fetch_url error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'fetch_url error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -233,7 +237,7 @@ async def generate_image( :return: Confirmation that the image was generated, or an error message """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -245,7 +249,7 @@ async def generate_image( ) # Prepare file entries for the images - image_files = [{"type": "image", "url": img["url"]} for img in images] + image_files = [{'type': 'image', 'url': img['url']} for img in images] # Persist files to DB if chat context is available if __chat_id__ and __message_id__ and images: @@ -261,26 +265,26 @@ async def generate_image( if __event_emitter__ and image_files: await __event_emitter__( { - "type": "chat:message:files", - "data": { - "files": image_files, + 'type': 'chat:message:files', + 'data': { + 'files': image_files, }, } ) # Return a message indicating the image is already displayed return json.dumps( { - "status": "success", - "message": "The image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.", - "images": images, + 'status': 'success', + 'message': 'The image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.', + 'images': images, }, ensure_ascii=False, ) - return json.dumps({"status": "success", "images": images}, ensure_ascii=False) + return json.dumps({'status': 'success', 'images': images}, ensure_ascii=False) except Exception as e: - log.exception(f"generate_image error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'generate_image error: {e}') + return json.dumps({'error': str(e)}) async def edit_image( @@ -300,7 +304,7 @@ async def edit_image( :return: Confirmation that the images were edited, or an error message """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -312,7 +316,7 @@ async def edit_image( ) # Prepare file entries for the images - image_files = [{"type": "image", "url": img["url"]} for img in images] + image_files = [{'type': 'image', 'url': img['url']} for img in images] # Persist files to DB if chat context is available if __chat_id__ and __message_id__ and images: @@ -328,26 +332,26 @@ async def edit_image( if __event_emitter__ and image_files: await __event_emitter__( { - "type": "chat:message:files", - "data": { - "files": image_files, + 'type': 'chat:message:files', + 'data': { + 'files': image_files, }, } ) # Return a message indicating the image is already displayed return json.dumps( { - "status": "success", - "message": "The edited image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.", - "images": images, + 'status': 'success', + 'message': 'The edited image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.', + 'images': images, }, ensure_ascii=False, ) - return json.dumps({"status": "success", "images": images}, ensure_ascii=False) + return json.dumps({'status': 'success', 'images': images}, ensure_ascii=False) except Exception as e: - log.exception(f"edit_image error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'edit_image error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -376,7 +380,7 @@ async def execute_code( from uuid import uuid4 if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: # Sanitize code (strips ANSI codes and markdown fences) @@ -406,47 +410,39 @@ async def execute_code( builtins.__import__ = restricted_import """) - code = blocking_code + "\n" + code + code = blocking_code + '\n' + code - engine = getattr( - __request__.app.state.config, "CODE_INTERPRETER_ENGINE", "pyodide" - ) - if engine == "pyodide": + engine = getattr(__request__.app.state.config, 'CODE_INTERPRETER_ENGINE', 'pyodide') + if engine == 'pyodide': # Execute via frontend pyodide using bidirectional event call if __event_call__ is None: return json.dumps( - { - "error": "Event call not available. WebSocket connection required for pyodide execution." - } + {'error': 'Event call not available. WebSocket connection required for pyodide execution.'} ) output = await __event_call__( { - "type": "execute:python", - "data": { - "id": str(uuid4()), - "code": code, - "session_id": ( - __metadata__.get("session_id") if __metadata__ else None - ), - "files": ( - __metadata__.get("files", []) if __metadata__ else [] - ), + 'type': 'execute:python', + 'data': { + 'id': str(uuid4()), + 'code': code, + 'session_id': (__metadata__.get('session_id') if __metadata__ else None), + 'files': (__metadata__.get('files', []) if __metadata__ else []), }, } ) # Parse the output - pyodide returns dict with stdout, stderr, result if isinstance(output, dict): - stdout = output.get("stdout", "") - stderr = output.get("stderr", "") - result = output.get("result", "") + stdout = output.get('stdout', '') + stderr = output.get('stderr', '') + result = output.get('result', '') else: - stdout = "" - stderr = "" - result = str(output) if output else "" + stdout = '' + stderr = '' + result = str(output) if output else '' - elif engine == "jupyter": + elif engine == 'jupyter': from open_webui.utils.code_interpreter import execute_code_jupyter output = await execute_code_jupyter( @@ -454,39 +450,37 @@ async def execute_code( code, ( __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN - if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH - == "token" + if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH == 'token' else None ), ( __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD - if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH - == "password" + if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH == 'password' else None ), __request__.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, ) - stdout = output.get("stdout", "") - stderr = output.get("stderr", "") - result = output.get("result", "") + stdout = output.get('stdout', '') + stderr = output.get('stderr', '') + result = output.get('result', '') else: - return json.dumps({"error": f"Unknown code interpreter engine: {engine}"}) + return json.dumps({'error': f'Unknown code interpreter engine: {engine}'}) # Handle image outputs (base64 encoded) - replace with uploaded URLs # Get actual user object for image upload (upload_image requires user.id attribute) - if __user__ and __user__.get("id"): + if __user__ and __user__.get('id'): from open_webui.models.users import Users from open_webui.utils.files import get_image_url_from_base64 - user = Users.get_user_by_id(__user__["id"]) + user = Users.get_user_by_id(__user__['id']) # Extract and upload images from stdout if stdout and isinstance(stdout, str): - stdout_lines = stdout.split("\n") + stdout_lines = stdout.split('\n') for idx, line in enumerate(stdout_lines): - if "data:image/png;base64" in line: + if 'data:image/png;base64' in line: image_url = get_image_url_from_base64( __request__, line, @@ -494,14 +488,14 @@ async def execute_code( user, ) if image_url: - stdout_lines[idx] = f"![Output Image]({image_url})" - stdout = "\n".join(stdout_lines) + stdout_lines[idx] = f'![Output Image]({image_url})' + stdout = '\n'.join(stdout_lines) # Extract and upload images from result if result and isinstance(result, str): - result_lines = result.split("\n") + result_lines = result.split('\n') for idx, line in enumerate(result_lines): - if "data:image/png;base64" in line: + if 'data:image/png;base64' in line: image_url = get_image_url_from_base64( __request__, line, @@ -509,20 +503,20 @@ async def execute_code( user, ) if image_url: - result_lines[idx] = f"![Output Image]({image_url})" - result = "\n".join(result_lines) + result_lines[idx] = f'![Output Image]({image_url})' + result = '\n'.join(result_lines) response = { - "status": "success", - "stdout": stdout, - "stderr": stderr, - "result": result, + 'status': 'success', + 'stdout': stdout, + 'stderr': stderr, + 'result': result, } return json.dumps(response, ensure_ascii=False) except Exception as e: - log.exception(f"execute_code error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'execute_code error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -544,7 +538,7 @@ async def search_memories( :return: JSON with matching memories and their dates """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -555,27 +549,25 @@ async def search_memories( user, ) - if results and hasattr(results, "documents") and results.documents: + if results and hasattr(results, 'documents') and results.documents: memories = [] for doc_idx, doc in enumerate(results.documents[0]): memory_id = None if results.ids and results.ids[0]: memory_id = results.ids[0][doc_idx] - created_at = "Unknown" - if results.metadatas and results.metadatas[0][doc_idx].get( - "created_at" - ): + created_at = 'Unknown' + if results.metadatas and results.metadatas[0][doc_idx].get('created_at'): created_at = time.strftime( - "%Y-%m-%d", - time.localtime(results.metadatas[0][doc_idx]["created_at"]), + '%Y-%m-%d', + time.localtime(results.metadatas[0][doc_idx]['created_at']), ) - memories.append({"id": memory_id, "date": created_at, "content": doc}) + memories.append({'id': memory_id, 'date': created_at, 'content': doc}) return json.dumps(memories, ensure_ascii=False) else: return json.dumps([]) except Exception as e: - log.exception(f"search_memories error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_memories error: {e}') + return json.dumps({'error': str(e)}) async def add_memory( @@ -590,7 +582,7 @@ async def add_memory( :return: Confirmation that the memory was stored """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -601,10 +593,10 @@ async def add_memory( user, ) - return json.dumps({"status": "success", "id": memory.id}, ensure_ascii=False) + return json.dumps({'status': 'success', 'id': memory.id}, ensure_ascii=False) except Exception as e: - log.exception(f"add_memory error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'add_memory error: {e}') + return json.dumps({'error': str(e)}) async def replace_memory_content( @@ -621,7 +613,7 @@ async def replace_memory_content( :return: Confirmation that the memory was updated """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -634,12 +626,12 @@ async def replace_memory_content( ) return json.dumps( - {"status": "success", "id": memory.id, "content": memory.content}, + {'status': 'success', 'id': memory.id, 'content': memory.content}, ensure_ascii=False, ) except Exception as e: - log.exception(f"replace_memory_content error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'replace_memory_content error: {e}') + return json.dumps({'error': str(e)}) async def delete_memory( @@ -654,7 +646,7 @@ async def delete_memory( :return: Confirmation that the memory was deleted """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -662,18 +654,16 @@ async def delete_memory( result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: - VECTOR_DB_CLIENT.delete( - collection_name=f"user-memory-{user.id}", ids=[memory_id] - ) + VECTOR_DB_CLIENT.delete(collection_name=f'user-memory-{user.id}', ids=[memory_id]) return json.dumps( - {"status": "success", "message": f"Memory {memory_id} deleted"}, + {'status': 'success', 'message': f'Memory {memory_id} deleted'}, ensure_ascii=False, ) else: - return json.dumps({"error": "Memory not found or access denied"}) + return json.dumps({'error': 'Memory not found or access denied'}) except Exception as e: - log.exception(f"delete_memory error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'delete_memory error: {e}') + return json.dumps({'error': str(e)}) async def list_memories( @@ -686,7 +676,7 @@ async def list_memories( :return: JSON list of all memories with id, content, and dates """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) try: user = UserModel(**__user__) if __user__ else None @@ -696,14 +686,10 @@ async def list_memories( if memories: result = [ { - "id": m.id, - "content": m.content, - "created_at": time.strftime( - "%Y-%m-%d %H:%M", time.localtime(m.created_at) - ), - "updated_at": time.strftime( - "%Y-%m-%d %H:%M", time.localtime(m.updated_at) - ), + 'id': m.id, + 'content': m.content, + 'created_at': time.strftime('%Y-%m-%d %H:%M', time.localtime(m.created_at)), + 'updated_at': time.strftime('%Y-%m-%d %H:%M', time.localtime(m.updated_at)), } for m in memories ] @@ -711,8 +697,8 @@ async def list_memories( else: return json.dumps([]) except Exception as e: - log.exception(f"list_memories error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'list_memories error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -738,22 +724,22 @@ async def search_notes( :return: JSON with matching notes containing id, title, and content snippet """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] result = Notes.search_notes( user_id=user_id, filter={ - "query": query, - "user_id": user_id, - "group_ids": user_group_ids, - "permission": "read", + 'query': query, + 'user_id': user_id, + 'group_ids': user_group_ids, + 'permission': 'read', }, skip=0, limit=count * 3, # Fetch more for filtering @@ -772,9 +758,9 @@ async def search_notes( continue # Extract a snippet from the markdown content - content_snippet = "" - if note.data and note.data.get("content", {}).get("md"): - md_content = note.data["content"]["md"] + content_snippet = '' + if note.data and note.data.get('content', {}).get('md'): + md_content = note.data['content']['md'] lower_content = md_content.lower() lower_query = query.lower() idx = lower_content.find(lower_query) @@ -782,21 +768,17 @@ async def search_notes( start = max(0, idx - 50) end = min(len(md_content), idx + len(query) + 100) content_snippet = ( - ("..." if start > 0 else "") - + md_content[start:end] - + ("..." if end < len(md_content) else "") + ('...' if start > 0 else '') + md_content[start:end] + ('...' if end < len(md_content) else '') ) else: - content_snippet = md_content[:150] + ( - "..." if len(md_content) > 150 else "" - ) + content_snippet = md_content[:150] + ('...' if len(md_content) > 150 else '') notes.append( { - "id": note.id, - "title": note.title, - "snippet": content_snippet, - "updated_at": note.updated_at, + 'id': note.id, + 'title': note.title, + 'snippet': content_snippet, + 'updated_at': note.updated_at, } ) @@ -805,8 +787,8 @@ async def search_notes( return json.dumps(notes, ensure_ascii=False) except Exception as e: - log.exception(f"search_notes error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_notes error: {e}') + return json.dumps({'error': str(e)}) async def view_note( @@ -821,50 +803,50 @@ async def view_note( :return: JSON with the note's id, title, and full markdown content """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: note = Notes.get_note_by_id(note_id) if not note: - return json.dumps({"error": "Note not found"}) + return json.dumps({'error': 'Note not found'}) # Check access permission - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] from open_webui.models.access_grants import AccessGrants if note.user_id != user_id and not AccessGrants.has_access( user_id=user_id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="read", + permission='read', user_group_ids=set(user_group_ids), ): - return json.dumps({"error": "Access denied"}) + return json.dumps({'error': 'Access denied'}) # Extract markdown content - content = "" - if note.data and note.data.get("content", {}).get("md"): - content = note.data["content"]["md"] + content = '' + if note.data and note.data.get('content', {}).get('md'): + content = note.data['content']['md'] return json.dumps( { - "id": note.id, - "title": note.title, - "content": content, - "updated_at": note.updated_at, - "created_at": note.created_at, + 'id': note.id, + 'title': note.title, + 'content': content, + 'updated_at': note.updated_at, + 'created_at': note.created_at, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"view_note error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_note error: {e}') + return json.dumps({'error': str(e)}) async def write_note( @@ -881,39 +863,39 @@ async def write_note( :return: JSON with success status and new note id """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.notes import NoteForm - user_id = __user__.get("id") + user_id = __user__.get('id') form = NoteForm( title=title, - data={"content": {"md": content}}, + data={'content': {'md': content}}, access_grants=[], # Private by default - only owner can access ) new_note = Notes.insert_new_note(user_id, form) if not new_note: - return json.dumps({"error": "Failed to create note"}) + return json.dumps({'error': 'Failed to create note'}) return json.dumps( { - "status": "success", - "id": new_note.id, - "title": new_note.title, - "created_at": new_note.created_at, + 'status': 'success', + 'id': new_note.id, + 'title': new_note.title, + 'created_at': new_note.created_at, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"write_note error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'write_note error: {e}') + return json.dumps({'error': str(e)}) async def replace_note_content( @@ -932,10 +914,10 @@ async def replace_note_content( :return: JSON with success status and updated note info """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.notes import NoteUpdateForm @@ -943,46 +925,46 @@ async def replace_note_content( note = Notes.get_note_by_id(note_id) if not note: - return json.dumps({"error": "Note not found"}) + return json.dumps({'error': 'Note not found'}) # Check write permission - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] from open_webui.models.access_grants import AccessGrants if note.user_id != user_id and not AccessGrants.has_access( user_id=user_id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="write", + permission='write', user_group_ids=set(user_group_ids), ): - return json.dumps({"error": "Write access denied"}) + return json.dumps({'error': 'Write access denied'}) # Build update form - update_data = {"data": {"content": {"md": content}}} + update_data = {'data': {'content': {'md': content}}} if title: - update_data["title"] = title + update_data['title'] = title form = NoteUpdateForm(**update_data) updated_note = Notes.update_note_by_id(note_id, form) if not updated_note: - return json.dumps({"error": "Failed to update note"}) + return json.dumps({'error': 'Failed to update note'}) return json.dumps( { - "status": "success", - "id": updated_note.id, - "title": updated_note.title, - "updated_at": updated_note.updated_at, + 'status': 'success', + 'id': updated_note.id, + 'title': updated_note.title, + 'updated_at': updated_note.updated_at, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"replace_note_content error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'replace_note_content error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -1009,13 +991,13 @@ async def search_chats( :return: JSON with matching chats containing id, title, updated_at, and content snippet """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') chats = Chats.get_chats_by_user_id_and_search_text( user_id=user_id, @@ -1038,32 +1020,28 @@ async def search_chats( continue # Find a matching message snippet - snippet = "" - messages = chat.chat.get("history", {}).get("messages", {}) + snippet = '' + messages = chat.chat.get('history', {}).get('messages', {}) lower_query = query.lower() for msg_id, msg in messages.items(): - content = msg.get("content", "") + content = msg.get('content', '') if isinstance(content, str) and lower_query in content.lower(): idx = content.lower().find(lower_query) start = max(0, idx - 50) end = min(len(content), idx + len(query) + 100) - snippet = ( - ("..." if start > 0 else "") - + content[start:end] - + ("..." if end < len(content) else "") - ) + snippet = ('...' if start > 0 else '') + content[start:end] + ('...' if end < len(content) else '') break if not snippet and lower_query in chat.title.lower(): - snippet = f"Title match: {chat.title}" + snippet = f'Title match: {chat.title}' results.append( { - "id": chat.id, - "title": chat.title, - "snippet": snippet, - "updated_at": chat.updated_at, + 'id': chat.id, + 'title': chat.title, + 'snippet': snippet, + 'updated_at': chat.updated_at, } ) @@ -1072,8 +1050,8 @@ async def search_chats( return json.dumps(results, ensure_ascii=False) except Exception as e: - log.exception(f"search_chats error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_chats error: {e}') + return json.dumps({'error': str(e)}) async def view_chat( @@ -1088,26 +1066,26 @@ async def view_chat( :return: JSON with the chat's id, title, and messages """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id) if not chat: - return json.dumps({"error": "Chat not found or access denied"}) + return json.dumps({'error': 'Chat not found or access denied'}) # Extract messages from history messages = [] - history = chat.chat.get("history", {}) - msg_dict = history.get("messages", {}) + history = chat.chat.get('history', {}) + msg_dict = history.get('messages', {}) # Build message chain from currentId - current_id = history.get("currentId") + current_id = history.get('currentId') visited = set() while current_id and current_id not in visited: @@ -1116,28 +1094,28 @@ async def view_chat( if msg: messages.append( { - "role": msg.get("role", ""), - "content": msg.get("content", ""), + 'role': msg.get('role', ''), + 'content': msg.get('content', ''), } ) - current_id = msg.get("parentId") if msg else None + current_id = msg.get('parentId') if msg else None # Reverse to get chronological order messages.reverse() return json.dumps( { - "id": chat.id, - "title": chat.title, - "messages": messages, - "updated_at": chat.updated_at, - "created_at": chat.created_at, + 'id': chat.id, + 'title': chat.title, + 'messages': messages, + 'updated_at': chat.updated_at, + 'created_at': chat.created_at, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"view_chat error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_chat error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -1159,13 +1137,13 @@ async def search_channels( :return: JSON with matching channels containing id, name, description, and type """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') # Get all channels the user has access to all_channels = Channels.get_channels_by_user_id(user_id) @@ -1176,15 +1154,15 @@ async def search_channels( for channel in all_channels: name_match = lower_query in channel.name.lower() if channel.name else False - desc_match = lower_query in (channel.description or "").lower() + desc_match = lower_query in (channel.description or '').lower() if name_match or desc_match: matching_channels.append( { - "id": channel.id, - "name": channel.name, - "description": channel.description or "", - "type": channel.type or "public", + 'id': channel.id, + 'name': channel.name, + 'description': channel.description or '', + 'type': channel.type or 'public', } ) @@ -1193,8 +1171,8 @@ async def search_channels( return json.dumps(matching_channels, ensure_ascii=False) except Exception as e: - log.exception(f"search_channels error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_channels error: {e}') + return json.dumps({'error': str(e)}) async def search_channel_messages( @@ -1215,13 +1193,13 @@ async def search_channel_messages( :return: JSON with matching messages containing channel info, message content, and thread context """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') # Get all channels the user has access to user_channels = Channels.get_channels_by_user_id(user_id) @@ -1249,36 +1227,32 @@ async def search_channel_messages( channel = channel_map.get(msg.channel_id) # Extract snippet around the match - content = msg.content or "" + content = msg.content or '' lower_query = query.lower() idx = content.lower().find(lower_query) if idx != -1: start = max(0, idx - 50) end = min(len(content), idx + len(query) + 100) - snippet = ( - ("..." if start > 0 else "") - + content[start:end] - + ("..." if end < len(content) else "") - ) + snippet = ('...' if start > 0 else '') + content[start:end] + ('...' if end < len(content) else '') else: - snippet = content[:150] + ("..." if len(content) > 150 else "") + snippet = content[:150] + ('...' if len(content) > 150 else '') results.append( { - "channel_id": msg.channel_id, - "channel_name": channel.name if channel else "Unknown", - "message_id": msg.id, - "content_snippet": snippet, - "is_thread_reply": msg.parent_id is not None, - "parent_id": msg.parent_id, - "created_at": msg.created_at, + 'channel_id': msg.channel_id, + 'channel_name': channel.name if channel else 'Unknown', + 'message_id': msg.id, + 'content_snippet': snippet, + 'is_thread_reply': msg.parent_id is not None, + 'parent_id': msg.parent_id, + 'created_at': msg.created_at, } ) return json.dumps(results, ensure_ascii=False) except Exception as e: - log.exception(f"search_channel_messages error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_channel_messages error: {e}') + return json.dumps({'error': str(e)}) async def view_channel_message( @@ -1293,53 +1267,53 @@ async def view_channel_message( :return: JSON with the message content, channel info, and thread replies if any """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') message = Messages.get_message_by_id(message_id) if not message: - return json.dumps({"error": "Message not found"}) + return json.dumps({'error': 'Message not found'}) # Verify user has access to the channel channel = Channels.get_channel_by_id(message.channel_id) if not channel: - return json.dumps({"error": "Channel not found"}) + return json.dumps({'error': 'Channel not found'}) # Check if user has access to the channel user_channels = Channels.get_channels_by_user_id(user_id) channel_ids = [c.id for c in user_channels] if message.channel_id not in channel_ids: - return json.dumps({"error": "Access denied"}) + return json.dumps({'error': 'Access denied'}) # Build response with thread information result = { - "id": message.id, - "channel_id": message.channel_id, - "channel_name": channel.name, - "content": message.content, - "user_id": message.user_id, - "is_thread_reply": message.parent_id is not None, - "parent_id": message.parent_id, - "reply_count": message.reply_count, - "created_at": message.created_at, - "updated_at": message.updated_at, + 'id': message.id, + 'channel_id': message.channel_id, + 'channel_name': channel.name, + 'content': message.content, + 'user_id': message.user_id, + 'is_thread_reply': message.parent_id is not None, + 'parent_id': message.parent_id, + 'reply_count': message.reply_count, + 'created_at': message.created_at, + 'updated_at': message.updated_at, } # Include user info if available if message.user: - result["user_name"] = message.user.name + result['user_name'] = message.user.name return json.dumps(result, ensure_ascii=False) except Exception as e: - log.exception(f"view_channel_message error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_channel_message error: {e}') + return json.dumps({'error': str(e)}) async def view_channel_thread( @@ -1354,30 +1328,30 @@ async def view_channel_thread( :return: JSON with the parent message and all thread replies in chronological order """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: - user_id = __user__.get("id") + user_id = __user__.get('id') # Get the parent message parent_message = Messages.get_message_by_id(parent_message_id) if not parent_message: - return json.dumps({"error": "Message not found"}) + return json.dumps({'error': 'Message not found'}) # Verify user has access to the channel channel = Channels.get_channel_by_id(parent_message.channel_id) if not channel: - return json.dumps({"error": "Channel not found"}) + return json.dumps({'error': 'Channel not found'}) user_channels = Channels.get_channels_by_user_id(user_id) channel_ids = [c.id for c in user_channels] if parent_message.channel_id not in channel_ids: - return json.dumps({"error": "Access denied"}) + return json.dumps({'error': 'Access denied'}) # Get all thread replies thread_replies = Messages.get_thread_replies_by_message_id(parent_message_id) @@ -1388,12 +1362,12 @@ async def view_channel_thread( # Add parent message first messages.append( { - "id": parent_message.id, - "content": parent_message.content, - "user_id": parent_message.user_id, - "user_name": parent_message.user.name if parent_message.user else None, - "is_parent": True, - "created_at": parent_message.created_at, + 'id': parent_message.id, + 'content': parent_message.content, + 'user_id': parent_message.user_id, + 'user_name': parent_message.user.name if parent_message.user else None, + 'is_parent': True, + 'created_at': parent_message.created_at, } ) @@ -1401,29 +1375,29 @@ async def view_channel_thread( for reply in reversed(thread_replies): messages.append( { - "id": reply.id, - "content": reply.content, - "user_id": reply.user_id, - "user_name": reply.user.name if reply.user else None, - "is_parent": False, - "reply_to_id": reply.reply_to_id, - "created_at": reply.created_at, + 'id': reply.id, + 'content': reply.content, + 'user_id': reply.user_id, + 'user_name': reply.user.name if reply.user else None, + 'is_parent': False, + 'reply_to_id': reply.reply_to_id, + 'created_at': reply.created_at, } ) return json.dumps( { - "channel_id": parent_message.channel_id, - "channel_name": channel.name, - "thread_id": parent_message_id, - "message_count": len(messages), - "messages": messages, + 'channel_id': parent_message.channel_id, + 'channel_name': channel.name, + 'thread_id': parent_message_id, + 'message_count': len(messages), + 'messages': messages, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"view_channel_thread error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_channel_thread error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -1445,23 +1419,23 @@ async def list_knowledge_bases( :return: JSON with KBs containing id, name, description, and file_count """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.knowledge import Knowledges - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] result = Knowledges.search_knowledge_bases( user_id, filter={ - "query": "", - "user_id": user_id, - "group_ids": user_group_ids, + 'query': '', + 'user_id': user_id, + 'group_ids': user_group_ids, }, skip=skip, limit=count, @@ -1474,18 +1448,18 @@ async def list_knowledge_bases( knowledge_bases.append( { - "id": knowledge_base.id, - "name": knowledge_base.name, - "description": knowledge_base.description or "", - "file_count": file_count, - "updated_at": knowledge_base.updated_at, + 'id': knowledge_base.id, + 'name': knowledge_base.name, + 'description': knowledge_base.description or '', + 'file_count': file_count, + 'updated_at': knowledge_base.updated_at, } ) return json.dumps(knowledge_bases, ensure_ascii=False) except Exception as e: - log.exception(f"list_knowledge_bases error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'list_knowledge_bases error: {e}') + return json.dumps({'error': str(e)}) async def search_knowledge_bases( @@ -1504,23 +1478,23 @@ async def search_knowledge_bases( :return: JSON with matching KBs containing id, name, description, and file_count """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.knowledge import Knowledges - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] result = Knowledges.search_knowledge_bases( user_id, filter={ - "query": query, - "user_id": user_id, - "group_ids": user_group_ids, + 'query': query, + 'user_id': user_id, + 'group_ids': user_group_ids, }, skip=skip, limit=count, @@ -1533,18 +1507,18 @@ async def search_knowledge_bases( knowledge_bases.append( { - "id": knowledge_base.id, - "name": knowledge_base.name, - "description": knowledge_base.description or "", - "file_count": file_count, - "updated_at": knowledge_base.updated_at, + 'id': knowledge_base.id, + 'name': knowledge_base.name, + 'description': knowledge_base.description or '', + 'file_count': file_count, + 'updated_at': knowledge_base.updated_at, } ) return json.dumps(knowledge_bases, ensure_ascii=False) except Exception as e: - log.exception(f"search_knowledge_bases error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_knowledge_bases error: {e}') + return json.dumps({'error': str(e)}) async def search_knowledge_files( @@ -1565,31 +1539,31 @@ async def search_knowledge_files( :return: JSON with matching files containing id, filename, and updated_at """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.knowledge import Knowledges - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] if knowledge_id: result = Knowledges.search_files_by_id( knowledge_id=knowledge_id, user_id=user_id, - filter={"query": query}, + filter={'query': query}, skip=skip, limit=count, ) else: result = Knowledges.search_knowledge_files( filter={ - "query": query, - "user_id": user_id, - "group_ids": user_group_ids, + 'query': query, + 'user_id': user_id, + 'group_ids': user_group_ids, }, skip=skip, limit=count, @@ -1598,19 +1572,19 @@ async def search_knowledge_files( files = [] for file in result.items: file_info = { - "id": file.id, - "filename": file.filename, - "updated_at": file.updated_at, + 'id': file.id, + 'filename': file.filename, + 'updated_at': file.updated_at, } - if hasattr(file, "collection") and file.collection: - file_info["knowledge_id"] = file.collection.get("id", "") - file_info["knowledge_name"] = file.collection.get("name", "") + if hasattr(file, 'collection') and file.collection: + file_info['knowledge_id'] = file.collection.get('id', '') + file_info['knowledge_name'] = file.collection.get('name', '') files.append(file_info) return json.dumps(files, ensure_ascii=False) except Exception as e: - log.exception(f"search_knowledge_files error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'search_knowledge_files error: {e}') + return json.dumps({'error': str(e)}) async def view_file( @@ -1626,54 +1600,53 @@ async def view_file( :return: JSON with the file's id, filename, and full text content """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.files import Files from open_webui.utils.access_control.files import has_access_to_file - user_id = __user__.get("id") - user_role = __user__.get("role", "user") + user_id = __user__.get('id') + user_role = __user__.get('role', 'user') file = Files.get_file_by_id(file_id) if not file: - return json.dumps({"error": "File not found"}) + return json.dumps({'error': 'File not found'}) if ( file.user_id != user_id - and user_role != "admin" + and user_role != 'admin' and not any( - item.get("type") == "file" and item.get("id") == file_id - for item in (__model_knowledge__ or []) + item.get('type') == 'file' and item.get('id') == file_id for item in (__model_knowledge__ or []) ) and not has_access_to_file( file_id=file_id, - access_type="read", + access_type='read', user=UserModel(**__user__), ) ): - return json.dumps({"error": "File not found"}) + return json.dumps({'error': 'File not found'}) - content = "" + content = '' if file.data: - content = file.data.get("content", "") + content = file.data.get('content', '') return json.dumps( { - "id": file.id, - "filename": file.filename, - "content": content, - "updated_at": file.updated_at, - "created_at": file.created_at, + 'id': file.id, + 'filename': file.filename, + 'content': content, + 'updated_at': file.updated_at, + 'created_at': file.created_at, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"view_file error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_file error: {e}') + return json.dumps({'error': str(e)}) async def view_knowledge_file( @@ -1688,23 +1661,23 @@ async def view_knowledge_file( :return: JSON with the file's id, filename, and full text content """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges from open_webui.models.access_grants import AccessGrants - user_id = __user__.get("id") - user_role = __user__.get("role", "user") + user_id = __user__.get('id') + user_role = __user__.get('role', 'user') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] file = Files.get_file_by_id(file_id) if not file: - return json.dumps({"error": "File not found"}) + return json.dumps({'error': 'File not found'}) # Check access via any KB containing this file knowledges = Knowledges.get_knowledges_by_file_id(file_id) @@ -1713,43 +1686,43 @@ async def view_knowledge_file( for knowledge_base in knowledges: if ( - user_role == "admin" + user_role == 'admin' or knowledge_base.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, - permission="read", + permission='read', user_group_ids=set(user_group_ids), ) ): has_knowledge_access = True - knowledge_info = {"id": knowledge_base.id, "name": knowledge_base.name} + knowledge_info = {'id': knowledge_base.id, 'name': knowledge_base.name} break if not has_knowledge_access: - if file.user_id != user_id and user_role != "admin": - return json.dumps({"error": "Access denied"}) + if file.user_id != user_id and user_role != 'admin': + return json.dumps({'error': 'Access denied'}) - content = "" + content = '' if file.data: - content = file.data.get("content", "") + content = file.data.get('content', '') result = { - "id": file.id, - "filename": file.filename, - "content": content, - "updated_at": file.updated_at, - "created_at": file.created_at, + 'id': file.id, + 'filename': file.filename, + 'content': content, + 'updated_at': file.updated_at, + 'created_at': file.created_at, } if knowledge_info: - result["knowledge_id"] = knowledge_info["id"] - result["knowledge_name"] = knowledge_info["name"] + result['knowledge_id'] = knowledge_info['id'] + result['knowledge_name'] = knowledge_info['name'] return json.dumps(result, ensure_ascii=False) except Exception as e: - log.exception(f"view_knowledge_file error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_knowledge_file error: {e}') + return json.dumps({'error': str(e)}) async def query_knowledge_files( @@ -1770,10 +1743,10 @@ async def query_knowledge_files( :return: JSON with relevant chunks containing content, source filename, and relevance score """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) # Coerce parameters from LLM tool calls (may come as strings) if isinstance(count, str): @@ -1784,7 +1757,7 @@ async def query_knowledge_files( # Handle knowledge_ids being string "None", "null", or empty if isinstance(knowledge_ids, str): - if knowledge_ids.lower() in ("none", "null", ""): + if knowledge_ids.lower() in ('none', 'null', ''): knowledge_ids = None else: # Try to parse as JSON array if it looks like one @@ -1801,13 +1774,13 @@ async def query_knowledge_files( from open_webui.retrieval.utils import query_collection from open_webui.models.access_grants import AccessGrants - user_id = __user__.get("id") - user_role = __user__.get("role", "user") + user_id = __user__.get('id') + user_role = __user__.get('role', 'user') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] embedding_function = __request__.app.state.EMBEDDING_FUNCTION if not embedding_function: - return json.dumps({"error": "Embedding function not configured"}) + return json.dumps({'error': 'Embedding function not configured'}) collection_names = [] note_results = [] # Notes aren't vectorized, handle separately @@ -1815,51 +1788,51 @@ async def query_knowledge_files( # If model has attached knowledge, use those if __model_knowledge__: for item in __model_knowledge__: - item_type = item.get("type") - item_id = item.get("id") + item_type = item.get('type') + item_id = item.get('id') - if item_type == "collection": + if item_type == 'collection': # Knowledge base - use KB ID as collection name knowledge = Knowledges.get_knowledge_by_id(item_id) if knowledge and ( - user_role == "admin" + user_role == 'admin' or knowledge.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="read", + permission='read', user_group_ids=set(user_group_ids), ) ): collection_names.append(item_id) - elif item_type == "file": + elif item_type == 'file': # Individual file - use file-{id} as collection name file = Files.get_file_by_id(item_id) if file: - collection_names.append(f"file-{item_id}") + collection_names.append(f'file-{item_id}') - elif item_type == "note": + elif item_type == 'note': # Note - always return full content as context note = Notes.get_note_by_id(item_id) if note and ( - user_role == "admin" + user_role == 'admin' or note.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="note", + resource_type='note', resource_id=note.id, - permission="read", + permission='read', ) ): - content = note.data.get("content", {}).get("md", "") + content = note.data.get('content', {}).get('md', '') note_results.append( { - "content": content, - "source": note.title, - "note_id": note.id, - "type": "note", + 'content': content, + 'source': note.title, + 'note_id': note.id, + 'type': 'note', } ) @@ -1868,13 +1841,13 @@ async def query_knowledge_files( for knowledge_id in knowledge_ids: knowledge = Knowledges.get_knowledge_by_id(knowledge_id) if knowledge and ( - user_role == "admin" + user_role == 'admin' or knowledge.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="read", + permission='read', user_group_ids=set(user_group_ids), ) ): @@ -1884,9 +1857,9 @@ async def query_knowledge_files( result = Knowledges.search_knowledge_bases( user_id, filter={ - "query": "", - "user_id": user_id, - "group_ids": user_group_ids, + 'query': '', + 'user_id': user_id, + 'group_ids': user_group_ids, }, skip=0, limit=50, @@ -1907,21 +1880,19 @@ async def query_knowledge_files( k=count, ) - if query_results and "documents" in query_results: - documents = query_results.get("documents", [[]])[0] - metadatas = query_results.get("metadatas", [[]])[0] - distances = query_results.get("distances", [[]])[0] + if query_results and 'documents' in query_results: + documents = query_results.get('documents', [[]])[0] + metadatas = query_results.get('metadatas', [[]])[0] + distances = query_results.get('distances', [[]])[0] for idx, doc in enumerate(documents): chunk_info = { - "content": doc, - "source": metadatas[idx].get( - "source", metadatas[idx].get("name", "Unknown") - ), - "file_id": metadatas[idx].get("file_id", ""), + 'content': doc, + 'source': metadatas[idx].get('source', metadatas[idx].get('name', 'Unknown')), + 'file_id': metadatas[idx].get('file_id', ''), } if idx < len(distances): - chunk_info["distance"] = distances[idx] + chunk_info['distance'] = distances[idx] chunks.append(chunk_info) # Limit to requested count @@ -1929,8 +1900,8 @@ async def query_knowledge_files( return json.dumps(chunks, ensure_ascii=False) except Exception as e: - log.exception(f"query_knowledge_files error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'query_knowledge_files error: {e}') + return json.dumps({'error': str(e)}) async def query_knowledge_bases( @@ -1949,10 +1920,10 @@ async def query_knowledge_bases( :return: JSON with matching KBs (id, name, description, similarity) """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: import heapq @@ -1960,7 +1931,7 @@ async def query_knowledge_bases( from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT - user_id = __user__.get("id") + user_id = __user__.get('id') user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] query_embedding = await __request__.app.state.EMBEDDING_FUNCTION(query) @@ -1973,7 +1944,7 @@ async def query_knowledge_bases( while True: accessible_knowledge_bases = Knowledges.search_knowledge_bases( user_id, - filter={"user_id": user_id, "group_ids": user_group_ids}, + filter={'user_id': user_id, 'group_ids': user_group_ids}, skip=page_offset, limit=page_size, ) @@ -1986,17 +1957,13 @@ async def query_knowledge_bases( search_results = VECTOR_DB_CLIENT.search( collection_name=KNOWLEDGE_BASES_COLLECTION, vectors=[query_embedding], - filter={"knowledge_base_id": {"$in": accessible_ids}}, + filter={'knowledge_base_id': {'$in': accessible_ids}}, limit=count, ) if search_results and search_results.ids and search_results.ids[0]: result_ids = search_results.ids[0] - result_distances = ( - search_results.distances[0] - if search_results.distances - else [0] * len(result_ids) - ) + result_distances = search_results.distances[0] if search_results.distances else [0] * len(result_ids) for knowledge_base_id, distance in zip(result_ids, result_distances): if knowledge_base_id in seen_ids: @@ -2006,9 +1973,7 @@ async def query_knowledge_bases( if len(top_results_heap) < count: heapq.heappush(top_results_heap, (distance, knowledge_base_id)) elif distance > top_results_heap[0][0]: - heapq.heapreplace( - top_results_heap, (distance, knowledge_base_id) - ) + heapq.heapreplace(top_results_heap, (distance, knowledge_base_id)) page_offset += page_size if len(accessible_knowledge_bases.items) < page_size: @@ -2025,18 +1990,18 @@ async def query_knowledge_bases( if knowledge_base: matching_knowledge_bases.append( { - "id": knowledge_base.id, - "name": knowledge_base.name, - "description": knowledge_base.description or "", - "similarity": round(distance, 4), + 'id': knowledge_base.id, + 'name': knowledge_base.name, + 'description': knowledge_base.description or '', + 'similarity': round(distance, 4), } ) return json.dumps(matching_knowledge_bases, ensure_ascii=False) except Exception as e: - log.exception(f"query_knowledge_bases error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'query_knowledge_bases error: {e}') + return json.dumps({'error': str(e)}) # ============================================================================= @@ -2057,45 +2022,43 @@ async def view_skill( :return: The full skill instructions as markdown content """ if __request__ is None: - return json.dumps({"error": "Request context not available"}) + return json.dumps({'error': 'Request context not available'}) if not __user__: - return json.dumps({"error": "User context not available"}) + return json.dumps({'error': 'User context not available'}) try: from open_webui.models.skills import Skills from open_webui.models.access_grants import AccessGrants - user_id = __user__.get("id") + user_id = __user__.get('id') # Direct DB lookup by unique name skill = Skills.get_skill_by_name(name) if not skill or not skill.is_active: - return json.dumps({"error": f"Skill '{name}' not found"}) + return json.dumps({'error': f"Skill '{name}' not found"}) # Check user access - user_role = __user__.get("role", "user") - if user_role != "admin" and skill.user_id != user_id: - user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id) - ] + user_role = __user__.get('role', 'user') + if user_role != 'admin' and skill.user_id != user_id: + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] if not AccessGrants.has_access( user_id=user_id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, - permission="read", + permission='read', user_group_ids=set(user_group_ids), ): - return json.dumps({"error": "Access denied"}) + return json.dumps({'error': 'Access denied'}) return json.dumps( { - "name": skill.name, - "content": skill.content, + 'name': skill.name, + 'content': skill.content, }, ensure_ascii=False, ) except Exception as e: - log.exception(f"view_skill error: {e}") - return json.dumps({"error": str(e)}) + log.exception(f'view_skill error: {e}') + return json.dumps({'error': str(e)}) diff --git a/backend/open_webui/utils/access_control/__init__.py b/backend/open_webui/utils/access_control/__init__.py index b2fb9ece17..3ee394acc9 100644 --- a/backend/open_webui/utils/access_control/__init__.py +++ b/backend/open_webui/utils/access_control/__init__.py @@ -13,9 +13,7 @@ from open_webui.config import DEFAULT_USER_PERMISSIONS from sqlalchemy.orm import Session -def fill_missing_permissions( - permissions: dict[str, Any], default_permissions: dict[str, Any] -) -> dict[str, Any]: +def fill_missing_permissions(permissions: dict[str, Any], default_permissions: dict[str, Any]) -> dict[str, Any]: """ Recursively fills in missing properties in the permissions dictionary using the default permissions as a template. @@ -23,9 +21,7 @@ def fill_missing_permissions( for key, value in default_permissions.items(): if key not in permissions: permissions[key] = value - elif isinstance(value, dict) and isinstance( - permissions[key], dict - ): # Both are nested dictionaries + elif isinstance(value, dict) and isinstance(permissions[key], dict): # Both are nested dictionaries permissions[key] = fill_missing_permissions(permissions[key], value) return permissions @@ -42,9 +38,7 @@ def get_permissions( Permissions are nested in a dict with the permission key as the key and a boolean as the value. """ - def combine_permissions( - permissions: dict[str, Any], group_permissions: dict[str, Any] - ) -> dict[str, Any]: + def combine_permissions(permissions: dict[str, Any], group_permissions: dict[str, Any]) -> dict[str, Any]: """Combine permissions from multiple groups by taking the most permissive value.""" for key, value in group_permissions.items(): if isinstance(value, dict): @@ -55,9 +49,7 @@ def get_permissions( if key not in permissions: permissions[key] = value else: - permissions[key] = ( - permissions[key] or value - ) # Use the most permissive value (True > False) + permissions[key] = permissions[key] or value # Use the most permissive value (True > False) return permissions user_groups = Groups.get_groups_by_member_id(user_id, db=db) @@ -97,7 +89,7 @@ def has_permission( return bool(permissions) # Return the boolean at the final level - permission_hierarchy = permission_key.split(".") + permission_hierarchy = permission_key.split('.') # Retrieve user group permissions user_groups = Groups.get_groups_by_member_id(user_id, db=db) @@ -107,15 +99,13 @@ def has_permission( return True # Check default permissions afterward if the group permissions don't allow it - default_permissions = fill_missing_permissions( - default_permissions, DEFAULT_USER_PERMISSIONS - ) + default_permissions = fill_missing_permissions(default_permissions, DEFAULT_USER_PERMISSIONS) return get_permission(default_permissions, permission_hierarchy) def has_access( user_id: str, - permission: str = "read", + permission: str = 'read', access_grants: list | None = None, user_group_ids: set[str] | None = None, db: Session | None = None, @@ -141,19 +131,13 @@ def has_access( for grant in access_grants: if not isinstance(grant, dict): continue - if grant.get("permission") != permission: + if grant.get('permission') != permission: continue - principal_type = grant.get("principal_type") - principal_id = grant.get("principal_id") - if principal_type == "user" and ( - principal_id == "*" or principal_id == user_id - ): + principal_type = grant.get('principal_type') + principal_id = grant.get('principal_id') + if principal_type == 'user' and (principal_id == '*' or principal_id == user_id): return True - if ( - principal_type == "group" - and user_group_ids - and principal_id in user_group_ids - ): + if principal_type == 'group' and user_group_ids and principal_id in user_group_ids: return True return False @@ -174,19 +158,17 @@ def has_connection_access( """ from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL: return True if user_group_ids is None: user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} - access_grants = (connection.get("config") or {}).get("access_grants", []) - return has_access(user.id, "read", access_grants, user_group_ids) + access_grants = (connection.get('config') or {}).get('access_grants', []) + return has_access(user.id, 'read', access_grants, user_group_ids) -def migrate_access_control( - data: dict, ac_key: str = "access_control", grants_key: str = "access_grants" -) -> None: +def migrate_access_control(data: dict, ac_key: str = 'access_control', grants_key: str = 'access_grants') -> None: """ Auto-migrate a config dict in-place from legacy access_control dict to access_grants list. @@ -202,24 +184,24 @@ def migrate_access_control( grants: list[dict[str, str]] = [] if access_control and isinstance(access_control, dict): - for perm in ["read", "write"]: + for perm in ['read', 'write']: perm_data = access_control.get(perm, {}) if not perm_data: continue - for group_id in perm_data.get("group_ids", []): + for group_id in perm_data.get('group_ids', []): grants.append( { - "principal_type": "group", - "principal_id": group_id, - "permission": perm, + 'principal_type': 'group', + 'principal_id': group_id, + 'permission': perm, } ) - for uid in perm_data.get("user_ids", []): + for uid in perm_data.get('user_ids', []): grants.append( { - "principal_type": "user", - "principal_id": uid, - "permission": perm, + 'principal_type': 'user', + 'principal_id': uid, + 'permission': perm, } ) @@ -239,7 +221,7 @@ def filter_allowed_access_grants( Checks if the user has the required permissions to grant access to a resource. Returns the filtered list of access grants if permissions are missing. """ - if user_role == "admin" or not access_grants: + if user_role == 'admin' or not access_grants: return access_grants # Check if user can share publicly @@ -253,25 +235,17 @@ def filter_allowed_access_grants( grant for grant in access_grants if not ( - ( - grant.get("principal_type") - if isinstance(grant, dict) - else getattr(grant, "principal_type", None) - ) - == "user" - and ( - grant.get("principal_id") - if isinstance(grant, dict) - else getattr(grant, "principal_id", None) - ) - == "*" + (grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None)) + == 'user' + and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) + == '*' ) ] # Strip individual user sharing if user lacks permission if has_user_access_grant(access_grants) and not has_permission( user_id, - "access_grants.allow_users", + 'access_grants.allow_users', default_permissions, db=db, ): diff --git a/backend/open_webui/utils/access_control/files.py b/backend/open_webui/utils/access_control/files.py index c69f3dfa9d..a7e35fd506 100644 --- a/backend/open_webui/utils/access_control/files.py +++ b/backend/open_webui/utils/access_control/files.py @@ -31,7 +31,7 @@ def has_access_to_file( file.user_id == user.id separately before calling this. """ file = Files.get_file_by_id(file_id, db=db) - log.debug(f"Checking if user has {access_type} access to file") + log.debug(f'Checking if user has {access_type} access to file') if not file: return False @@ -41,13 +41,11 @@ def has_access_to_file( # Check if the file is associated with any knowledge bases the user has access to knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} for knowledge_base in knowledge_bases: if knowledge_base.user_id == user.id or AccessGrants.has_access( user_id=user.id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, permission=access_type, user_group_ids=user_group_ids, @@ -55,18 +53,16 @@ def has_access_to_file( ): return True - knowledge_base_id = file.meta.get("collection_name") if file.meta else None + knowledge_base_id = file.meta.get('collection_name') if file.meta else None if knowledge_base_id: - knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( - user.id, access_type, db=db - ) + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, access_type, db=db) for knowledge_base in knowledge_bases: if knowledge_base.id == knowledge_base_id: return True # Check if the file is associated with any channels the user has access to channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db) - if access_type == "read" and channels: + if access_type == 'read' and channels: return True # Check if the file is associated with any chats the user has access to @@ -77,13 +73,9 @@ def has_access_to_file( # Check if the file is directly attached to a shared workspace model for model in Models.get_models_by_user_id(user.id, permission=access_type, db=db): - knowledge_items = getattr(model.meta, "knowledge", None) or [] + knowledge_items = getattr(model.meta, 'knowledge', None) or [] for item in knowledge_items: - if ( - isinstance(item, dict) - and item.get("type") == "file" - and item.get("id") == file.id - ): + if isinstance(item, dict) and item.get('type') == 'file' and item.get('id') == file.id: return True return False diff --git a/backend/open_webui/utils/actions.py b/backend/open_webui/utils/actions.py index 0b4b817f0a..5c5712fa0f 100644 --- a/backend/open_webui/utils/actions.py +++ b/backend/open_webui/utils/actions.py @@ -21,70 +21,70 @@ log = logging.getLogger(__name__) async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") + if '.' in action_id: + action_id, sub_action_id = action_id.split('.') else: sub_action_id = None action = Functions.get_function_by_id(action_id) if not action: - raise Exception(f"Action not found: {action_id}") + raise Exception(f'Action not found: {action_id}') if not request.app.state.MODELS: await get_all_models(request, user=user) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS data = form_data - model_id = data["model"] + model_id = data['model'] if model_id not in models: - raise Exception("Model not found") + raise Exception('Model not found') model = models[model_id] __event_emitter__ = get_event_emitter( { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, + 'chat_id': data['chat_id'], + 'message_id': data['id'], + 'session_id': data['session_id'], + 'user_id': user.id, } ) __event_call__ = get_event_call( { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, + 'chat_id': data['chat_id'], + 'message_id': data['id'], + 'session_id': data['session_id'], + 'user_id': user.id, } ) function_module, _, _ = get_function_module_from_cache(request, action_id) - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'): valves = Functions.get_function_valves_by_id(action_id) function_module.valves = function_module.Valves(**(valves if valves else {})) - if hasattr(function_module, "action"): + if hasattr(function_module, 'action'): try: action = function_module.action # Get the signature of the function sig = inspect.signature(action) - params = {"body": data} + params = {'body': data} # Extra parameters to be passed to the function extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__request__": request, + '__model__': model, + '__id__': sub_action_id if sub_action_id is not None else action_id, + '__event_emitter__': __event_emitter__, + '__event_call__': __event_call__, + '__request__': request, } # Add extra params in contained in function signature @@ -92,20 +92,18 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A if key in sig.parameters: params[key] = value - if "__user__" in sig.parameters: + if '__user__' in sig.parameters: __user__ = user.model_dump() if isinstance(user, UserModel) else {} try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) + if hasattr(function_module, 'UserValves'): + __user__['valves'] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(action_id, user.id) ) except Exception as e: - log.exception(f"Failed to get user values: {e}") + log.exception(f'Failed to get user values: {e}') - params = {**params, "__user__": __user__} + params = {**params, '__user__': __user__} if inspect.iscoroutinefunction(action): data = await action(**params) @@ -117,15 +115,15 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A request, action_id, data, - "action", + 'action', ) if action_embeds: await __event_emitter__( { - "type": "embeds", - "data": { - "embeds": action_embeds, + 'type': 'embeds', + 'data': { + 'embeds': action_embeds, }, } ) @@ -134,6 +132,6 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A data = processed_result except Exception as e: - raise Exception(f"Error: {e}") + raise Exception(f'Error: {e}') return data diff --git a/backend/open_webui/utils/anthropic.py b/backend/open_webui/utils/anthropic.py index 736f7238bc..5ba4099fb4 100644 --- a/backend/open_webui/utils/anthropic.py +++ b/backend/open_webui/utils/anthropic.py @@ -16,7 +16,7 @@ log = logging.getLogger(__name__) def is_anthropic_url(url: str) -> bool: """Check if the URL is an Anthropic API endpoint.""" - return "api.anthropic.com" in url + return 'api.anthropic.com' in url async def get_anthropic_models(url: str, key: str, user: UserModel = None) -> dict: @@ -31,56 +31,56 @@ async def get_anthropic_models(url: str, key: str, user: UserModel = None) -> di try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: headers = { - "x-api-key": key, - "anthropic-version": "2023-06-01", + 'x-api-key': key, + 'anthropic-version': '2023-06-01', } if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) while True: - params = {"limit": 1000} + params = {'limit': 1000} if after_id: - params["after_id"] = after_id + params['after_id'] = after_id async with session.get( - f"{url}/models", + f'{url}/models', headers=headers, params=params, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: if response.status != 200: - error_detail = f"HTTP Error: {response.status}" + error_detail = f'HTTP Error: {response.status}' try: res = await response.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" + if 'error' in res: + error_detail = f'External Error: {res["error"]}' except Exception: pass - return {"object": "list", "data": [], "error": error_detail} + return {'object': 'list', 'data': [], 'error': error_detail} data = await response.json() - for model in data.get("data", []): + for model in data.get('data', []): all_models.append( { - "id": model.get("id"), - "object": "model", - "created": 0, - "owned_by": "anthropic", - "name": model.get("display_name", model.get("id")), + 'id': model.get('id'), + 'object': 'model', + 'created': 0, + 'owned_by': 'anthropic', + 'name': model.get('display_name', model.get('id')), } ) - if not data.get("has_more", False): + if not data.get('has_more', False): break - after_id = data.get("last_id") + after_id = data.get('last_id') except Exception as e: - log.error(f"Anthropic connection error: {e}") + log.error(f'Anthropic connection error: {e}') return None - return {"object": "list", "data": all_models} + return {'object': 'list', 'data': all_models} ############################## @@ -102,245 +102,241 @@ def convert_anthropic_to_openai_payload(anthropic_payload: dict) -> dict: openai_payload = {} # Model - openai_payload["model"] = anthropic_payload.get("model", "") + openai_payload['model'] = anthropic_payload.get('model', '') # Build messages list messages = [] # System prompt (Anthropic has it as top-level, OpenAI as a system message) - system = anthropic_payload.get("system") + system = anthropic_payload.get('system') if system: if isinstance(system, str): - messages.append({"role": "system", "content": system}) + messages.append({'role': 'system', 'content': system}) elif isinstance(system, list): # Anthropic supports system as list of content blocks text_parts = [] for block in system: - if isinstance(block, dict) and block.get("type") == "text": - text_parts.append(block.get("text", "")) + if isinstance(block, dict) and block.get('type') == 'text': + text_parts.append(block.get('text', '')) elif isinstance(block, str): text_parts.append(block) - messages.append({"role": "system", "content": "\n".join(text_parts)}) + messages.append({'role': 'system', 'content': '\n'.join(text_parts)}) # Convert messages - for msg in anthropic_payload.get("messages", []): - role = msg.get("role", "user") - content = msg.get("content") + for msg in anthropic_payload.get('messages', []): + role = msg.get('role', 'user') + content = msg.get('content') if isinstance(content, str): - messages.append({"role": role, "content": content}) + messages.append({'role': role, 'content': content}) elif isinstance(content, list): # Convert Anthropic content blocks to OpenAI format openai_content = [] tool_calls = [] for block in content: - block_type = block.get("type", "text") + block_type = block.get('type', 'text') - if block_type == "text": + if block_type == 'text': openai_content.append( { - "type": "text", - "text": block.get("text", ""), + 'type': 'text', + 'text': block.get('text', ''), } ) - elif block_type == "image": - source = block.get("source", {}) - if source.get("type") == "base64": - media_type = source.get("media_type", "image/png") - data = source.get("data", "") + elif block_type == 'image': + source = block.get('source', {}) + if source.get('type') == 'base64': + media_type = source.get('media_type', 'image/png') + data = source.get('data', '') openai_content.append( { - "type": "image_url", - "image_url": { - "url": f"data:{media_type};base64,{data}", + 'type': 'image_url', + 'image_url': { + 'url': f'data:{media_type};base64,{data}', }, } ) - elif source.get("type") == "url": + elif source.get('type') == 'url': openai_content.append( { - "type": "image_url", - "image_url": {"url": source.get("url", "")}, + 'type': 'image_url', + 'image_url': {'url': source.get('url', '')}, } ) - elif block_type == "tool_use": + elif block_type == 'tool_use': tool_calls.append( { - "id": block.get("id", ""), - "type": "function", - "function": { - "name": block.get("name", ""), - "arguments": ( - json.dumps(block.get("input", {})) - if isinstance(block.get("input"), dict) - else str(block.get("input", "{}")) + 'id': block.get('id', ''), + 'type': 'function', + 'function': { + 'name': block.get('name', ''), + 'arguments': ( + json.dumps(block.get('input', {})) + if isinstance(block.get('input'), dict) + else str(block.get('input', '{}')) ), }, } ) - elif block_type == "tool_result": + elif block_type == 'tool_result': # Tool results become separate tool messages in OpenAI format - tool_content = block.get("content", "") + tool_content = block.get('content', '') if isinstance(tool_content, list): tool_text_parts = [] for tc in tool_content: - if isinstance(tc, dict) and tc.get("type") == "text": - tool_text_parts.append(tc.get("text", "")) - tool_content = "\n".join(tool_text_parts) + if isinstance(tc, dict) and tc.get('type') == 'text': + tool_text_parts.append(tc.get('text', '')) + tool_content = '\n'.join(tool_text_parts) # Propagate error status if present - if block.get("is_error"): - tool_content = f"Error: {tool_content}" + if block.get('is_error'): + tool_content = f'Error: {tool_content}' messages.append( { - "role": "tool", - "tool_call_id": block.get("tool_use_id", ""), - "content": tool_content, + 'role': 'tool', + 'tool_call_id': block.get('tool_use_id', ''), + 'content': tool_content, } ) # Build the message if tool_calls: # Assistant message with tool calls - msg_dict = {"role": role} + msg_dict = {'role': role} if openai_content: # If there's only text, flatten it - if len(openai_content) == 1 and openai_content[0]["type"] == "text": - msg_dict["content"] = openai_content[0]["text"] + if len(openai_content) == 1 and openai_content[0]['type'] == 'text': + msg_dict['content'] = openai_content[0]['text'] else: - msg_dict["content"] = openai_content + msg_dict['content'] = openai_content else: - msg_dict["content"] = "" - msg_dict["tool_calls"] = tool_calls + msg_dict['content'] = '' + msg_dict['tool_calls'] = tool_calls messages.append(msg_dict) elif openai_content: # If there's only a single text block, flatten it to a string - if len(openai_content) == 1 and openai_content[0]["type"] == "text": - messages.append( - {"role": role, "content": openai_content[0]["text"]} - ) + if len(openai_content) == 1 and openai_content[0]['type'] == 'text': + messages.append({'role': role, 'content': openai_content[0]['text']}) else: - messages.append({"role": role, "content": openai_content}) + messages.append({'role': role, 'content': openai_content}) else: - messages.append({"role": role, "content": str(content) if content else ""}) + messages.append({'role': role, 'content': str(content) if content else ''}) - openai_payload["messages"] = messages + openai_payload['messages'] = messages # max_tokens - if "max_tokens" in anthropic_payload: - openai_payload["max_tokens"] = anthropic_payload["max_tokens"] + if 'max_tokens' in anthropic_payload: + openai_payload['max_tokens'] = anthropic_payload['max_tokens'] # Common parameters - for param in ("temperature", "top_p", "stop_sequences", "stream"): + for param in ('temperature', 'top_p', 'stop_sequences', 'stream'): if param in anthropic_payload: - if param == "stop_sequences": - openai_payload["stop"] = anthropic_payload[param] + if param == 'stop_sequences': + openai_payload['stop'] = anthropic_payload[param] else: openai_payload[param] = anthropic_payload[param] # Tools conversion: Anthropic → OpenAI - if "tools" in anthropic_payload: + if 'tools' in anthropic_payload: openai_tools = [] - for tool in anthropic_payload["tools"]: + for tool in anthropic_payload['tools']: openai_tools.append( { - "type": "function", - "function": { - "name": tool.get("name", ""), - "description": tool.get("description", ""), - "parameters": tool.get("input_schema", {}), + 'type': 'function', + 'function': { + 'name': tool.get('name', ''), + 'description': tool.get('description', ''), + 'parameters': tool.get('input_schema', {}), }, } ) - openai_payload["tools"] = openai_tools + openai_payload['tools'] = openai_tools # tool_choice - if "tool_choice" in anthropic_payload: - tc = anthropic_payload["tool_choice"] + if 'tool_choice' in anthropic_payload: + tc = anthropic_payload['tool_choice'] if isinstance(tc, dict): - tc_type = tc.get("type", "auto") - if tc_type == "auto": - openai_payload["tool_choice"] = "auto" - elif tc_type == "any": - openai_payload["tool_choice"] = "required" - elif tc_type == "tool": - openai_payload["tool_choice"] = { - "type": "function", - "function": {"name": tc.get("name", "")}, + tc_type = tc.get('type', 'auto') + if tc_type == 'auto': + openai_payload['tool_choice'] = 'auto' + elif tc_type == 'any': + openai_payload['tool_choice'] = 'required' + elif tc_type == 'tool': + openai_payload['tool_choice'] = { + 'type': 'function', + 'function': {'name': tc.get('name', '')}, } return openai_payload -def convert_openai_to_anthropic_response( - openai_response: dict, model: str = "" -) -> dict: +def convert_openai_to_anthropic_response(openai_response: dict, model: str = '') -> dict: """ Convert a non-streaming OpenAI Chat Completions response to Anthropic Messages format. """ import uuid as _uuid choice = {} - if openai_response.get("choices"): - choice = openai_response["choices"][0] + if openai_response.get('choices'): + choice = openai_response['choices'][0] - message = choice.get("message", {}) - finish_reason = choice.get("finish_reason", "stop") + message = choice.get('message', {}) + finish_reason = choice.get('finish_reason', 'stop') # Map finish_reason to stop_reason stop_reason_map = { - "stop": "end_turn", - "length": "max_tokens", - "tool_calls": "tool_use", - "content_filter": "end_turn", + 'stop': 'end_turn', + 'length': 'max_tokens', + 'tool_calls': 'tool_use', + 'content_filter': 'end_turn', } - stop_reason = stop_reason_map.get(finish_reason, "end_turn") + stop_reason = stop_reason_map.get(finish_reason, 'end_turn') # Build content blocks content = [] - msg_content = message.get("content") + msg_content = message.get('content') if msg_content: - content.append({"type": "text", "text": msg_content}) + content.append({'type': 'text', 'text': msg_content}) # Tool calls → tool_use blocks - tool_calls = message.get("tool_calls", []) + tool_calls = message.get('tool_calls', []) for tc in tool_calls: - func = tc.get("function", {}) + func = tc.get('function', {}) try: - tool_input = json.loads(func.get("arguments", "{}")) + tool_input = json.loads(func.get('arguments', '{}')) except (json.JSONDecodeError, TypeError): tool_input = {} content.append( { - "type": "tool_use", - "id": tc.get("id", f"toolu_{_uuid.uuid4().hex[:24]}"), - "name": func.get("name", ""), - "input": tool_input, + 'type': 'tool_use', + 'id': tc.get('id', f'toolu_{_uuid.uuid4().hex[:24]}'), + 'name': func.get('name', ''), + 'input': tool_input, } ) # Usage - openai_usage = openai_response.get("usage", {}) + openai_usage = openai_response.get('usage', {}) usage = { - "input_tokens": openai_usage.get("prompt_tokens", 0), - "output_tokens": openai_usage.get("completion_tokens", 0), + 'input_tokens': openai_usage.get('prompt_tokens', 0), + 'output_tokens': openai_usage.get('completion_tokens', 0), } return { - "id": openai_response.get("id", f"msg_{_uuid.uuid4().hex[:24]}"), - "type": "message", - "role": "assistant", - "content": content, - "model": model or openai_response.get("model", ""), - "stop_reason": stop_reason, - "stop_sequence": None, - "usage": usage, + 'id': openai_response.get('id', f'msg_{_uuid.uuid4().hex[:24]}'), + 'type': 'message', + 'role': 'assistant', + 'content': content, + 'model': model or openai_response.get('model', ''), + 'stop_reason': stop_reason, + 'stop_sequence': None, + 'usage': usage, } -async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str = ""): +async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str = ''): """ Convert an OpenAI SSE streaming response to Anthropic Messages SSE format. @@ -352,10 +348,10 @@ async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str """ import uuid as _uuid - msg_id = f"msg_{_uuid.uuid4().hex[:24]}" + msg_id = f'msg_{_uuid.uuid4().hex[:24]}' input_tokens = 0 output_tokens = 0 - stop_reason = "end_turn" + stop_reason = 'end_turn' # Track content blocks with a running index. # Each text block or tool_use block gets its own index. @@ -369,35 +365,35 @@ async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str # Emit message_start message_start = { - "type": "message_start", - "message": { - "id": msg_id, - "type": "message", - "role": "assistant", - "content": [], - "model": model, - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 0, "output_tokens": 0}, + 'type': 'message_start', + 'message': { + 'id': msg_id, + 'type': 'message', + 'role': 'assistant', + 'content': [], + 'model': model, + 'stop_reason': None, + 'stop_sequence': None, + 'usage': {'input_tokens': 0, 'output_tokens': 0}, }, } - yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n".encode() + yield f'event: message_start\ndata: {json.dumps(message_start)}\n\n'.encode() try: async for chunk in openai_stream_generator: if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8", errors="ignore") + chunk = chunk.decode('utf-8', errors='ignore') - for line in chunk.strip().split("\n"): + for line in chunk.strip().split('\n'): line = line.strip() - if not line or not line.startswith("data:"): + if not line or not line.startswith('data:'): continue data_str = line[5:].strip() - if data_str == "[DONE]": + if data_str == '[DONE]': continue - if data_str == "{}": + if data_str == '{}': continue try: @@ -405,62 +401,58 @@ async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str except (json.JSONDecodeError, TypeError): continue - choices = data.get("choices", []) + choices = data.get('choices', []) if not choices: # Check for usage in the final chunk - if data.get("usage"): - input_tokens = data["usage"].get("prompt_tokens", input_tokens) - output_tokens = data["usage"].get( - "completion_tokens", output_tokens - ) + if data.get('usage'): + input_tokens = data['usage'].get('prompt_tokens', input_tokens) + output_tokens = data['usage'].get('completion_tokens', output_tokens) continue - delta = choices[0].get("delta", {}) - finish_reason = choices[0].get("finish_reason") + delta = choices[0].get('delta', {}) + finish_reason = choices[0].get('finish_reason') # Update usage if present - if data.get("usage"): - input_tokens = data["usage"].get("prompt_tokens", input_tokens) - output_tokens = data["usage"].get( - "completion_tokens", output_tokens - ) + if data.get('usage'): + input_tokens = data['usage'].get('prompt_tokens', input_tokens) + output_tokens = data['usage'].get('completion_tokens', output_tokens) # --- Handle text content --- - content = delta.get("content") + content = delta.get('content') if content is not None: if not text_block_open: # Start a new text content block block_start = { - "type": "content_block_start", - "index": current_block_index, - "content_block": {"type": "text", "text": ""}, + 'type': 'content_block_start', + 'index': current_block_index, + 'content_block': {'type': 'text', 'text': ''}, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n".encode() + yield f'event: content_block_start\ndata: {json.dumps(block_start)}\n\n'.encode() text_block_open = True # Send text delta block_delta = { - "type": "content_block_delta", - "index": current_block_index, - "delta": {"type": "text_delta", "text": content}, + 'type': 'content_block_delta', + 'index': current_block_index, + 'delta': {'type': 'text_delta', 'text': content}, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n".encode() + yield f'event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n'.encode() # --- Handle tool calls --- - tool_calls = delta.get("tool_calls") + tool_calls = delta.get('tool_calls') if tool_calls: # Close text block if one is open (text comes before tools) if text_block_open: block_stop = { - "type": "content_block_stop", - "index": current_block_index, + 'type': 'content_block_stop', + 'index': current_block_index, } - yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n".encode() + yield f'event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n'.encode() text_block_open = False current_block_index += 1 for tc in tool_calls: - tc_index = tc.get("index", 0) + tc_index = tc.get('index', 0) if tc_index not in tool_call_started: # First time seeing this tool call — emit content_block_start @@ -468,67 +460,67 @@ async def openai_stream_to_anthropic_stream(openai_stream_generator, model: str tool_call_started[tc_index] = True # Extract tool call ID and name from the first chunk - tc_id = tc.get("id", f"toolu_{_uuid.uuid4().hex[:24]}") - tc_name = tc.get("function", {}).get("name", "") + tc_id = tc.get('id', f'toolu_{_uuid.uuid4().hex[:24]}') + tc_name = tc.get('function', {}).get('name', '') block_start = { - "type": "content_block_start", - "index": current_block_index, - "content_block": { - "type": "tool_use", - "id": tc_id, - "name": tc_name, - "input": {}, + 'type': 'content_block_start', + 'index': current_block_index, + 'content_block': { + 'type': 'tool_use', + 'id': tc_id, + 'name': tc_name, + 'input': {}, }, } - yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n".encode() + yield f'event: content_block_start\ndata: {json.dumps(block_start)}\n\n'.encode() current_block_index += 1 # Emit argument chunks as input_json_delta - args_chunk = tc.get("function", {}).get("arguments", "") + args_chunk = tc.get('function', {}).get('arguments', '') if args_chunk: block_delta = { - "type": "content_block_delta", - "index": tool_call_blocks[tc_index], - "delta": { - "type": "input_json_delta", - "partial_json": args_chunk, + 'type': 'content_block_delta', + 'index': tool_call_blocks[tc_index], + 'delta': { + 'type': 'input_json_delta', + 'partial_json': args_chunk, }, } - yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n".encode() + yield f'event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n'.encode() # --- Handle finish reason --- if finish_reason is not None: stop_reason_map = { - "stop": "end_turn", - "length": "max_tokens", - "tool_calls": "tool_use", + 'stop': 'end_turn', + 'length': 'max_tokens', + 'tool_calls': 'tool_use', } - stop_reason = stop_reason_map.get(finish_reason, "end_turn") + stop_reason = stop_reason_map.get(finish_reason, 'end_turn') except Exception as e: - log.error(f"Error in Anthropic stream conversion: {e}") + log.error(f'Error in Anthropic stream conversion: {e}') # Close any open text block if text_block_open: - block_stop = {"type": "content_block_stop", "index": current_block_index} - yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n".encode() + block_stop = {'type': 'content_block_stop', 'index': current_block_index} + yield f'event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n'.encode() # Close any open tool call blocks for tc_index, block_index in tool_call_blocks.items(): - block_stop = {"type": "content_block_stop", "index": block_index} - yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n".encode() + block_stop = {'type': 'content_block_stop', 'index': block_index} + yield f'event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n'.encode() # Emit message_delta with stop reason message_delta = { - "type": "message_delta", - "delta": { - "stop_reason": stop_reason, - "stop_sequence": None, + 'type': 'message_delta', + 'delta': { + 'stop_reason': stop_reason, + 'stop_sequence': None, }, - "usage": {"output_tokens": output_tokens}, + 'usage': {'output_tokens': output_tokens}, } - yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n".encode() + yield f'event: message_delta\ndata: {json.dumps(message_delta)}\n\n'.encode() # Emit message_stop - yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n".encode() + yield f'event: message_stop\ndata: {json.dumps({"type": "message_stop"})}\n\n'.encode() diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py index bf3c33774c..1200d813af 100644 --- a/backend/open_webui/utils/audit.py +++ b/backend/open_webui/utils/audit.py @@ -50,10 +50,10 @@ class AuditLogEntry: class AuditLevel(str, Enum): - NONE = "NONE" - METADATA = "METADATA" - REQUEST = "REQUEST" - REQUEST_RESPONSE = "REQUEST_RESPONSE" + NONE = 'NONE' + METADATA = 'METADATA' + REQUEST = 'REQUEST' + REQUEST_RESPONSE = 'REQUEST_RESPONSE' class AuditLogger: @@ -64,25 +64,24 @@ class AuditLogger: logger (Logger): An instance of Loguru’s logger. """ - def __init__(self, logger: "Logger"): + def __init__(self, logger: 'Logger'): self.logger = logger.bind(auditable=True) def write( self, audit_entry: AuditLogEntry, *, - log_level: str = "INFO", + log_level: str = 'INFO', extra: Optional[dict] = None, ): - entry = asdict(audit_entry) if extra: - entry["extra"] = extra + entry['extra'] = extra self.logger.log( log_level, - "", + '', **entry, ) @@ -106,15 +105,11 @@ class AuditContext: def add_request_chunk(self, chunk: bytes): if len(self.request_body) < self.max_body_size: - self.request_body.extend( - chunk[: self.max_body_size - len(self.request_body)] - ) + self.request_body.extend(chunk[: self.max_body_size - len(self.request_body)]) def add_response_chunk(self, chunk: bytes): if len(self.response_body) < self.max_body_size: - self.response_body.extend( - chunk[: self.max_body_size - len(self.response_body)] - ) + self.response_body.extend(chunk[: self.max_body_size - len(self.response_body)]) class AuditLoggingMiddleware: @@ -122,7 +117,7 @@ class AuditLoggingMiddleware: ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle. """ - AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"} + AUDITED_METHODS = {'PUT', 'PATCH', 'DELETE', 'POST'} def __init__( self, @@ -142,8 +137,8 @@ class AuditLoggingMiddleware: if self.included_paths and self.excluded_paths: logger.warning( - "Both AUDIT_INCLUDED_PATHS and AUDIT_EXCLUDED_PATHS are set. " - "AUDIT_INCLUDED_PATHS (whitelist) takes precedence." + 'Both AUDIT_INCLUDED_PATHS and AUDIT_EXCLUDED_PATHS are set. ' + 'AUDIT_INCLUDED_PATHS (whitelist) takes precedence.' ) async def __call__( @@ -152,7 +147,7 @@ class AuditLoggingMiddleware: receive: ASGIReceiveCallable, send: ASGISendCallable, ) -> None: - if scope["type"] != "http": + if scope['type'] != 'http': return await self.app(scope, receive, send) request = Request(scope=cast(MutableMapping, scope)) @@ -185,9 +180,7 @@ class AuditLoggingMiddleware: await self.app(scope, receive_wrapper, send_wrapper) @asynccontextmanager - async def _audit_context( - self, request: Request - ) -> AsyncGenerator[AuditContext, None]: + async def _audit_context(self, request: Request) -> AsyncGenerator[AuditContext, None]: """ async context manager that ensures that an audit log entry is recorded after the request is processed. """ @@ -198,29 +191,24 @@ class AuditLoggingMiddleware: await self._log_audit_entry(request, context) async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]: - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') try: - user = await get_current_user( - request, None, None, get_http_authorization_cred(auth_header) - ) + user = await get_current_user(request, None, None, get_http_authorization_cred(auth_header)) return user except Exception as e: - logger.debug(f"Failed to get authenticated user: {str(e)}") + logger.debug(f'Failed to get authenticated user: {str(e)}') return None def _should_skip_auditing(self, request: Request) -> bool: - if ( - request.method not in {"POST", "PUT", "PATCH", "DELETE"} - or AUDIT_LOG_LEVEL == "NONE" - ): + if request.method not in {'POST', 'PUT', 'PATCH', 'DELETE'} or AUDIT_LOG_LEVEL == 'NONE': return True ALWAYS_LOG_ENDPOINTS = { - "/api/v1/auths/signin", - "/api/v1/auths/signout", - "/api/v1/auths/signup", + '/api/v1/auths/signin', + '/api/v1/auths/signout', + '/api/v1/auths/signup', } path = request.url.path.lower() for endpoint in ALWAYS_LOG_ENDPOINTS: @@ -229,55 +217,47 @@ class AuditLoggingMiddleware: # Skip logging if the request is not authenticated # Check both Authorization header (API keys) and token cookie (browser sessions) - if not request.headers.get("authorization") and not request.cookies.get( - "token" - ): + if not request.headers.get('authorization') and not request.cookies.get('token'): return True # Whitelist mode: only log paths that match included_paths if self.included_paths: - pattern = re.compile( - r"^/api(?:/v1)?/(" + "|".join(self.included_paths) + r")\b" - ) + pattern = re.compile(r'^/api(?:/v1)?/(' + '|'.join(self.included_paths) + r')\b') if not pattern.match(request.url.path): return True # Skip: path not in whitelist return False # Do NOT skip: path is in whitelist # Blacklist mode: skip paths that match excluded_paths - pattern = re.compile( - r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b" - ) + pattern = re.compile(r'^/api(?:/v1)?/(' + '|'.join(self.excluded_paths) + r')\b') if pattern.match(request.url.path): return True return False async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext): - if message["type"] == "http.request": - body = message.get("body", b"") + if message['type'] == 'http.request': + body = message.get('body', b'') context.add_request_chunk(body) async def _capture_response(self, message: ASGISendEvent, context: AuditContext): - if message["type"] == "http.response.start": - context.metadata["response_status_code"] = message["status"] + if message['type'] == 'http.response.start': + context.metadata['response_status_code'] = message['status'] - elif message["type"] == "http.response.body": - body = message.get("body", b"") + elif message['type'] == 'http.response.body': + body = message.get('body', b'') context.add_response_chunk(body) async def _log_audit_entry(self, request: Request, context: AuditContext): try: user = await self._get_authenticated_user(request) - user = ( - user.model_dump(include={"id", "name", "email", "role"}) if user else {} - ) + user = user.model_dump(include={'id', 'name', 'email', 'role'}) if user else {} - request_body = context.request_body.decode("utf-8", errors="replace") - response_body = context.response_body.decode("utf-8", errors="replace") + request_body = context.request_body.decode('utf-8', errors='replace') + response_body = context.response_body.decode('utf-8', errors='replace') # Redact sensitive information - if "password" in request_body: + if 'password' in request_body: request_body = re.sub( r'"password":\s*"(.*?)"', '"password": "********"', @@ -290,13 +270,13 @@ class AuditLoggingMiddleware: audit_level=self.audit_level.value, verb=request.method, request_uri=str(request.url), - response_status_code=context.metadata.get("response_status_code", None), + response_status_code=context.metadata.get('response_status_code', None), source_ip=request.client.host if request.client else None, - user_agent=request.headers.get("user-agent"), + user_agent=request.headers.get('user-agent'), request_object=request_body, response_object=response_body, ) self.audit_logger.write(entry) except Exception as e: - logger.error(f"Failed to log audit entry: {str(e)}") + logger.error(f'Failed to log audit entry: {str(e)}') diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index a12c6db881..1a78f3143f 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -49,7 +49,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer log = logging.getLogger(__name__) SESSION_SECRET = WEBUI_SECRET_KEY -ALGORITHM = "HS256" +ALGORITHM = 'HS256' ############## # Auth Utils @@ -74,62 +74,60 @@ def verify_signature(payload: str, signature: str) -> bool: def override_static(path: str, content: str): # Ensure path is safe - if "/" in path or ".." in path: - log.error(f"Invalid path: {path}") + if '/' in path or '..' in path: + log.error(f'Invalid path: {path}') return file_path = os.path.join(STATIC_DIR, path) os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, "wb") as f: + with open(file_path, 'wb') as f: f.write(base64.b64decode(content)) # Convert Base64 back to raw binary def get_license_data(app, key): def data_handler(data): for k, v in data.items(): - if k == "resources": + if k == 'resources': for p, c in v.items(): - globals().get("override_static", lambda a, b: None)(p, c) - elif k == "count": - setattr(app.state, "USER_COUNT", v) - elif k == "name": - setattr(app.state, "WEBUI_NAME", v) - elif k == "metadata": - setattr(app.state, "LICENSE_METADATA", v) + globals().get('override_static', lambda a, b: None)(p, c) + elif k == 'count': + setattr(app.state, 'USER_COUNT', v) + elif k == 'name': + setattr(app.state, 'WEBUI_NAME', v) + elif k == 'metadata': + setattr(app.state, 'LICENSE_METADATA', v) def handler(u): res = requests.post( - f"{u}/api/v1/license/", - json={"key": key, "version": "1"}, + f'{u}/api/v1/license/', + json={'key': key, 'version': '1'}, timeout=5, ) - if getattr(res, "ok", False): - payload = getattr(res, "json", lambda: {})() + if getattr(res, 'ok', False): + payload = getattr(res, 'json', lambda: {})() data_handler(payload) return True else: - log.error( - f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}" - ) + log.error(f'License: retrieval issue: {getattr(res, "text", "unknown error")}') if key: us = [ - "https://api.openwebui.com", - "https://licenses.api.openwebui.com", + 'https://api.openwebui.com', + 'https://licenses.api.openwebui.com', ] try: for u in us: if handler(u): return True except Exception as ex: - log.exception(f"License: Uncaught Exception: {ex}") + log.exception(f'License: Uncaught Exception: {ex}') try: if LICENSE_BLOB: nl = 12 - kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest() + kb = hashlib.sha256((key.replace('-', '').upper()).encode()).digest() def nt(b): return b[:nl], b[nl:] @@ -139,19 +137,19 @@ def get_license_data(app, key): aesgcm = AESGCM(kb) p = json.loads(aesgcm.decrypt(ln, lt, None)) - pk.verify(base64.b64decode(p["s"]), p["p"].encode()) + pk.verify(base64.b64decode(p['s']), p['p'].encode()) - pb = base64.b64decode(p["p"]) + pb = base64.b64decode(p['p']) pn, pt = nt(pb) data = json.loads(aesgcm.decrypt(pn, pt, None).decode()) - if not data.get("exp") and data.get("exp") < datetime.now().date(): + if not data.get('exp') and data.get('exp') < datetime.now().date(): return False data_handler(data) return True except Exception as e: - log.error(f"License: {e}") + log.error(f'License: {e}') return False @@ -161,12 +159,12 @@ bearer_security = HTTPBearer(auto_error=False) def get_password_hash(password: str) -> str: """Hash a password using bcrypt""" - return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") + return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') def validate_password(password: str) -> bool: # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. - if len(password.encode("utf-8")) > 72: + if len(password.encode('utf-8')) > 72: raise Exception( ERROR_MESSAGES.PASSWORD_TOO_LONG, ) @@ -182,8 +180,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash""" return ( bcrypt.checkpw( - plain_password.encode("utf-8"), - hashed_password.encode("utf-8"), + plain_password.encode('utf-8'), + hashed_password.encode('utf-8'), ) if hashed_password else None @@ -195,10 +193,10 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st if expires_delta: expire = datetime.now(UTC) + expires_delta - payload.update({"exp": expire}) + payload.update({'exp': expire}) jti = str(uuid.uuid4()) - payload.update({"jti": jti}) + payload.update({'jti': jti}) encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM) return encoded_jwt @@ -215,12 +213,10 @@ def decode_token(token: str) -> Optional[dict]: async def is_valid_token(request, decoded) -> bool: # Require Redis to check revoked tokens if request.app.state.redis: - jti = decoded.get("jti") + jti = decoded.get('jti') if jti: - revoked = await request.app.state.redis.get( - f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked" - ) + revoked = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked') if revoked: return False @@ -236,37 +232,35 @@ async def invalidate_token(request, token): # Require Redis to store revoked tokens if request.app.state.redis: - jti = decoded.get("jti") - exp = decoded.get("exp") + jti = decoded.get('jti') + exp = decoded.get('exp') if jti and exp: - ttl = exp - int( - datetime.now(UTC).timestamp() - ) # Calculate time-to-live for the token + ttl = exp - int(datetime.now(UTC).timestamp()) # Calculate time-to-live for the token if ttl > 0: # Store the revoked token in Redis with an expiration time await request.app.state.redis.set( - f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked", - "1", + f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked', + '1', ex=ttl, ) def extract_token_from_auth_header(auth_header: str): - return auth_header[len("Bearer ") :] + return auth_header[len('Bearer ') :] def create_api_key(): - key = str(uuid.uuid4()).replace("-", "") - return f"sk-{key}" + key = str(uuid.uuid4()).replace('-', '') + return f'sk-{key}' def get_http_authorization_cred(auth_header: Optional[str]): if not auth_header: return None try: - scheme, credentials = auth_header.split(" ") + scheme, credentials = auth_header.split(' ') return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) except Exception: return None @@ -287,27 +281,27 @@ async def get_current_user( if auth_token is not None: token = auth_token.credentials - if token is None and "token" in request.cookies: - token = request.cookies.get("token") + if token is None and 'token' in request.cookies: + token = request.cookies.get('token') # Fallback to request.state.token (set by middleware, e.g. for x-api-key) - if token is None and hasattr(request.state, "token") and request.state.token: + if token is None and hasattr(request.state, 'token') and request.state.token: token = request.state.token.credentials if token is None: - raise HTTPException(status_code=401, detail="Not authenticated") + raise HTTPException(status_code=401, detail='Not authenticated') # auth by api key - if token.startswith("sk-"): + if token.startswith('sk-'): user = get_current_user_by_api_key(request, token) # Add user info to current span current_span = trace.get_current_span() if current_span: - current_span.set_attribute("client.user.id", user.id) - current_span.set_attribute("client.user.email", user.email) - current_span.set_attribute("client.user.role", user.role) - current_span.set_attribute("client.auth.type", "api_key") + current_span.set_attribute('client.user.id', user.id) + current_span.set_attribute('client.user.email', user.email) + current_span.set_attribute('client.user.role', user.role) + current_span.set_attribute('client.auth.type', 'api_key') return user @@ -318,17 +312,17 @@ async def get_current_user( except Exception as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", + detail='Invalid token', ) - if data is not None and "id" in data: - if data.get("jti") and not await is_valid_token(request, data): + if data is not None and 'id' in data: + if data.get('jti') and not await is_valid_token(request, data): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", + detail='Invalid token', ) - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data['id']) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -336,22 +330,20 @@ async def get_current_user( ) else: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: - trusted_email = request.headers.get( - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, "" - ).lower() + trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER, '').lower() if trusted_email and user.email != trusted_email: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="User mismatch. Please sign in again.", + detail='User mismatch. Please sign in again.', ) # Add user info to current span current_span = trace.get_current_span() if current_span: - current_span.set_attribute("client.user.id", user.id) - current_span.set_attribute("client.user.email", user.email) - current_span.set_attribute("client.user.role", user.role) - current_span.set_attribute("client.auth.type", "jwt") + current_span.set_attribute('client.user.id', user.id) + current_span.set_attribute('client.user.email', user.email) + current_span.set_attribute('client.user.role', user.role) + current_span.set_attribute('client.auth.type', 'jwt') # Refresh the user's last active timestamp asynchronously # to prevent blocking the request @@ -365,15 +357,15 @@ async def get_current_user( ) except Exception as e: # Delete the token cookie - if request.cookies.get("token"): - response.delete_cookie("token") + if request.cookies.get('token'): + response.delete_cookie('token') - if request.cookies.get("oauth_id_token"): - response.delete_cookie("oauth_id_token") + if request.cookies.get('oauth_id_token'): + response.delete_cookie('oauth_id_token') # Delete OAuth session if present - if request.cookies.get("oauth_session_id"): - response.delete_cookie("oauth_session_id") + if request.cookies.get('oauth_session_id'): + response.delete_cookie('oauth_session_id') raise e @@ -389,31 +381,29 @@ def get_current_user_by_api_key(request, api_key: str): ) if not request.state.enable_api_keys or ( - user.role != "admin" + user.role != 'admin' and not has_permission( user.id, - "features.api_keys", + 'features.api_keys', request.app.state.config.USER_PERMISSIONS, ) ): - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED - ) + raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED) # Add user info to current span current_span = trace.get_current_span() if current_span: - current_span.set_attribute("client.user.id", user.id) - current_span.set_attribute("client.user.email", user.email) - current_span.set_attribute("client.user.role", user.role) - current_span.set_attribute("client.auth.type", "api_key") + current_span.set_attribute('client.user.id', user.id) + current_span.set_attribute('client.user.email', user.email) + current_span.set_attribute('client.user.role', user.role) + current_span.set_attribute('client.auth.type', 'api_key') Users.update_last_active_by_id(user.id) return user def get_verified_user(user=Depends(get_current_user)): - if user.role not in {"user", "admin"}: + if user.role not in {'user', 'admin'}: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -422,7 +412,7 @@ def get_verified_user(user=Depends(get_current_user)): def get_admin_user(user=Depends(get_current_user)): - if user.role != "admin": + if user.role != 'admin': raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -430,7 +420,7 @@ def get_admin_user(user=Depends(get_current_user)): return user -def create_admin_user(email: str, password: str, name: str = "Admin"): +def create_admin_user(email: str, password: str, name: str = 'Admin'): """ Create an admin user from environment variables. Used for headless/automated deployments. @@ -441,24 +431,24 @@ def create_admin_user(email: str, password: str, name: str = "Admin"): return None if Users.has_users(): - log.debug("Users already exist, skipping admin creation") + log.debug('Users already exist, skipping admin creation') return None - log.info(f"Creating admin account from environment variables: {email}") + log.info(f'Creating admin account from environment variables: {email}') try: hashed = get_password_hash(password) user = Auths.insert_new_auth( email=email.lower(), password=hashed, name=name, - role="admin", + role='admin', ) if user: - log.info(f"Admin account created successfully: {email}") + log.info(f'Admin account created successfully: {email}') return user else: - log.error("Failed to create admin account from environment variables") + log.error('Failed to create admin account from environment variables') return None except Exception as e: - log.error(f"Error creating admin account: {e}") + log.error(f'Error creating admin account: {e}') return None diff --git a/backend/open_webui/utils/channels.py b/backend/open_webui/utils/channels.py index 312b5ea24c..6f85dfae1e 100644 --- a/backend/open_webui/utils/channels.py +++ b/backend/open_webui/utils/channels.py @@ -1,16 +1,16 @@ import re -def extract_mentions(message: str, triggerChar: str = "@"): +def extract_mentions(message: str, triggerChar: str = '@'): # Escape triggerChar in case it's a regex special character triggerChar = re.escape(triggerChar) - pattern = rf"<{triggerChar}([A-Z]):([^|>]+)" + pattern = rf'<{triggerChar}([A-Z]):([^|>]+)' matches = re.findall(pattern, message) - return [{"id_type": id_type, "id": id_value} for id_type, id_value in matches] + return [{'id_type': id_type, 'id': id_value} for id_type, id_value in matches] -def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = True): +def replace_mentions(message: str, triggerChar: str = '@', use_label: bool = True): """ Replace mentions in the message with either their label (after the pipe `|`) or their id if no label exists. @@ -27,5 +27,5 @@ def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = Tru return label if use_label and label else id_value # Regex captures: idType, id, optional label - pattern = rf"<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>" + pattern = rf'<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>' return re.sub(pattern, replacer, message) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 753ee56636..79a7991eca 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -62,20 +62,20 @@ async def generate_direct_chat_completion( user: Any, models: dict, ): - log.info("generate_direct_chat_completion") + log.info('generate_direct_chat_completion') - metadata = form_data.pop("metadata", {}) + metadata = form_data.pop('metadata', {}) - user_id = metadata.get("user_id") - session_id = metadata.get("session_id") + user_id = metadata.get('user_id') + session_id = metadata.get('session_id') request_id = str(uuid.uuid4()) # Generate a unique request ID event_caller = get_event_call(metadata) - channel = f"{user_id}:{session_id}:{request_id}" - logging.info(f"WebSocket channel: {channel}") + channel = f'{user_id}:{session_id}:{request_id}' + logging.info(f'WebSocket channel: {channel}') - if form_data.get("stream"): + if form_data.get('stream'): q = asyncio.Queue() async def message_listener(sid, data): @@ -90,19 +90,19 @@ async def generate_direct_chat_completion( # Start processing chat completion in background res = await event_caller( { - "type": "request:chat:completion", - "data": { - "form_data": form_data, - "model": models[form_data["model"]], - "channel": channel, - "session_id": session_id, + 'type': 'request:chat:completion', + 'data': { + 'form_data': form_data, + 'model': models[form_data['model']], + 'channel': channel, + 'session_id': session_id, }, } ) - log.info(f"res: {res}") + log.info(f'res: {res}') - if res.get("status", False): + if res.get('status', False): # Define a generator to stream responses async def event_generator(): nonlocal q @@ -110,47 +110,45 @@ async def generate_direct_chat_completion( while True: data = await q.get() # Wait for new messages if isinstance(data, dict): - if "done" in data and data["done"]: + if 'done' in data and data['done']: break # Stop streaming when 'done' is received - yield f"data: {json.dumps(data)}\n\n" + yield f'data: {json.dumps(data)}\n\n' elif isinstance(data, str): - if "data:" in data: - yield f"{data}\n\n" + if 'data:' in data: + yield f'{data}\n\n' else: - yield f"data: {data}\n\n" + yield f'data: {data}\n\n' except Exception as e: - log.debug(f"Error in event generator: {e}") + log.debug(f'Error in event generator: {e}') pass # Define a background task to run the event generator async def background(): try: - del sio.handlers["/"][channel] + del sio.handlers['/'][channel] except Exception as e: pass # Return the streaming response - return StreamingResponse( - event_generator(), media_type="text/event-stream", background=background - ) + return StreamingResponse(event_generator(), media_type='text/event-stream', background=background) else: raise Exception(str(res)) else: res = await event_caller( { - "type": "request:chat:completion", - "data": { - "form_data": form_data, - "model": models[form_data["model"]], - "channel": channel, - "session_id": session_id, + 'type': 'request:chat:completion', + 'data': { + 'form_data': form_data, + 'model': models[form_data['model']], + 'channel': channel, + 'session_id': session_id, }, } ) - if "error" in res and res["error"]: - raise Exception(res["error"]) + if 'error' in res and res['error']: + raise Exception(res['error']) return res @@ -162,7 +160,7 @@ async def generate_chat_completion( bypass_filter: bool = False, bypass_system_prompt: bool = False, ): - log.debug(f"generate_chat_completion: {form_data}") + log.debug(f'generate_chat_completion: {form_data}') if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True @@ -170,49 +168,47 @@ async def generate_chat_completion( # handlers (openai/ollama) can read it without exposing it as a query param. request.state.bypass_filter = bypass_filter - if hasattr(request.state, "metadata"): - if "metadata" not in form_data: - form_data["metadata"] = request.state.metadata + if hasattr(request.state, 'metadata'): + if 'metadata' not in form_data: + form_data['metadata'] = request.state.metadata else: - form_data["metadata"] = { - **form_data["metadata"], + form_data['metadata'] = { + **form_data['metadata'], **request.state.metadata, } - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } - log.debug(f"direct connection to model: {models}") + log.debug(f'direct connection to model: {models}') else: models = request.app.state.MODELS - model_id = form_data["model"] + model_id = form_data['model'] if model_id not in models: - raise Exception("Model not found") + raise Exception('Model not found') model = models[model_id] - if getattr(request.state, "direct", False): - return await generate_direct_chat_completion( - request, form_data, user=user, models=models - ) + if getattr(request.state, 'direct', False): + return await generate_direct_chat_completion(request, form_data, user=user, models=models) else: # Check if user has access to the model - if not bypass_filter and user.role == "user": + if not bypass_filter and user.role == 'user': try: check_model_access(user, model) except Exception as e: raise e - if model.get("owned_by") == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": + if model.get('owned_by') == 'arena': + model_ids = model.get('info', {}).get('meta', {}).get('model_ids') + filter_mode = model.get('info', {}).get('meta', {}).get('filter_mode') + if model_ids and filter_mode == 'exclude': model_ids = [ - model["id"] + model['id'] for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" and model["id"] not in model_ids + if model.get('owned_by') != 'arena' and model['id'] not in model_ids ] selected_model_id = None @@ -220,18 +216,16 @@ async def generate_chat_completion( selected_model_id = random.choice(model_ids) else: model_ids = [ - model["id"] - for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" + model['id'] for model in list(request.app.state.MODELS.values()) if model.get('owned_by') != 'arena' ] selected_model_id = random.choice(model_ids) - form_data["model"] = selected_model_id + form_data['model'] = selected_model_id - if form_data.get("stream") == True: + if form_data.get('stream') == True: async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + yield f'data: {json.dumps({"selected_model_id": selected_model_id})}\n\n' async for chunk in stream: yield chunk @@ -244,7 +238,7 @@ async def generate_chat_completion( ) return StreamingResponse( stream_wrapper(response.body_iterator), - media_type="text/event-stream", + media_type='text/event-stream', background=response.background, ) else: @@ -258,15 +252,13 @@ async def generate_chat_completion( bypass_system_prompt=bypass_system_prompt, ) ), - "selected_model_id": selected_model_id, + 'selected_model_id': selected_model_id, } - if model.get("pipe"): + if model.get('pipe'): # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - request, form_data, user=user, models=models - ) - if model.get("owned_by") == "ollama": + return await generate_function_chat_completion(request, form_data, user=user, models=models) + if model.get('owned_by') == 'ollama': # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) response = await generate_ollama_chat_completion( @@ -275,8 +267,8 @@ async def generate_chat_completion( user=user, bypass_system_prompt=bypass_system_prompt, ) - if form_data.get("stream"): - response.headers["content-type"] = "text/event-stream" + if form_data.get('stream'): + response.headers['content-type'] = 'text/event-stream' return StreamingResponse( convert_streaming_response_ollama_to_openai(response), headers=dict(response.headers), @@ -300,55 +292,53 @@ async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request, user=user) - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS data = form_data - model_id = data["model"] + model_id = data['model'] if model_id not in models: - raise Exception("Model not found") + raise Exception('Model not found') model = models[model_id] try: data = await process_pipeline_outlet_filter(request, data, user, models) except Exception as e: - raise Exception(f"Error: {e}") + raise Exception(f'Error: {e}') metadata = { - "chat_id": data["chat_id"], - "message_id": data["id"], - "filter_ids": data.get("filter_ids", []), - "session_id": data["session_id"], - "user_id": user.id, + 'chat_id': data['chat_id'], + 'message_id': data['id'], + 'filter_ids': data.get('filter_ids', []), + 'session_id': data['session_id'], + 'user_id': user.id, } extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__request__": request, - "__model__": model, + '__event_emitter__': get_event_emitter(metadata), + '__event_call__': get_event_call(metadata), + '__user__': user.model_dump() if isinstance(user, UserModel) else {}, + '__metadata__': metadata, + '__request__': request, + '__model__': model, } try: - filter_ids = get_sorted_filter_ids( - request, model, metadata.get("filter_ids", []) - ) + filter_ids = get_sorted_filter_ids(request, model, metadata.get('filter_ids', [])) filter_functions = Functions.get_functions_by_ids(filter_ids) result, _ = await process_filter_functions( request=request, filter_functions=filter_functions, - filter_type="outlet", + filter_type='outlet', form_data=data, extra_params=extra_params, ) return result except Exception as e: - raise Exception(f"Error: {e}") + raise Exception(f'Error: {e}') diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py index a5de56a6c1..3e30c419ae 100644 --- a/backend/open_webui/utils/code_interpreter.py +++ b/backend/open_webui/utils/code_interpreter.py @@ -16,9 +16,9 @@ class ResultModel(BaseModel): Execute Code Result Model """ - stdout: Optional[str] = "" - stderr: Optional[str] = "" - result: Optional[str] = "" + stdout: Optional[str] = '' + stderr: Optional[str] = '' + result: Optional[str] = '' class JupyterCodeExecuter: @@ -30,8 +30,8 @@ class JupyterCodeExecuter: self, base_url: str, code: str, - token: str = "", - password: str = "", + token: str = '', + password: str = '', timeout: int = 60, ): """ @@ -46,9 +46,9 @@ class JupyterCodeExecuter: self.token = token self.password = password self.timeout = timeout - self.kernel_id = "" - if self.base_url[-1] != "/": - self.base_url += "/" + self.kernel_id = '' + if self.base_url[-1] != '/': + self.base_url += '/' self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url) self.params = {} self.result = ResultModel() @@ -59,12 +59,10 @@ class JupyterCodeExecuter: async def __aexit__(self, exc_type, exc_val, exc_tb): if self.kernel_id: try: - async with self.session.delete( - f"api/kernels/{self.kernel_id}", params=self.params - ) as response: + async with self.session.delete(f'api/kernels/{self.kernel_id}', params=self.params) as response: response.raise_for_status() except Exception as err: - logger.exception("close kernel failed, %s", err) + logger.exception('close kernel failed, %s', err) await self.session.close() async def run(self) -> ResultModel: @@ -73,23 +71,23 @@ class JupyterCodeExecuter: await self.init_kernel() await self.execute_code() except Exception as err: - logger.exception("execute code failed, %s", err) - self.result.stderr = f"Error: {err}" + logger.exception('execute code failed, %s', err) + self.result.stderr = f'Error: {err}' return self.result async def sign_in(self) -> None: # password authentication if self.password and not self.token: - async with self.session.get("login") as response: + async with self.session.get('login') as response: response.raise_for_status() - xsrf_token = response.cookies["_xsrf"].value + xsrf_token = response.cookies['_xsrf'].value if not xsrf_token: - raise ValueError("_xsrf token not found") + raise ValueError('_xsrf token not found') self.session.cookie_jar.update_cookies(response.cookies) - self.session.headers.update({"X-XSRFToken": xsrf_token}) + self.session.headers.update({'X-XSRFToken': xsrf_token}) async with self.session.post( - "login", - data={"_xsrf": xsrf_token, "password": self.password}, + 'login', + data={'_xsrf': xsrf_token, 'password': self.password}, allow_redirects=False, ) as response: response.raise_for_status() @@ -97,27 +95,22 @@ class JupyterCodeExecuter: # token authentication if self.token: - self.params.update({"token": self.token}) + self.params.update({'token': self.token}) async def init_kernel(self) -> None: - async with self.session.post(url="api/kernels", params=self.params) as response: + async with self.session.post(url='api/kernels', params=self.params) as response: response.raise_for_status() kernel_data = await response.json() - self.kernel_id = kernel_data["id"] + self.kernel_id = kernel_data['id'] def init_ws(self) -> (str, dict): - ws_base = self.base_url.replace("http", "ws", 1) - ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()]) - websocket_url = f"{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}" + ws_base = self.base_url.replace('http', 'ws', 1) + ws_params = '?' + '&'.join([f'{key}={val}' for key, val in self.params.items()]) + websocket_url = f'{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ""}' ws_headers = {} if self.password and not self.token: ws_headers = { - "Cookie": "; ".join( - [ - f"{cookie.key}={cookie.value}" - for cookie in self.session.cookie_jar - ] - ), + 'Cookie': '; '.join([f'{cookie.key}={cookie.value}' for cookie in self.session.cookie_jar]), **self.session.headers, } return websocket_url, ws_headers @@ -126,9 +119,7 @@ class JupyterCodeExecuter: # initialize ws websocket_url, ws_headers = self.init_ws() # execute - async with websockets.connect( - websocket_url, additional_headers=ws_headers - ) as ws: + async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws: await self.execute_in_jupyter(ws) async def execute_in_jupyter(self, ws) -> None: @@ -137,71 +128,69 @@ class JupyterCodeExecuter: await ws.send( json.dumps( { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "user", - "session": uuid.uuid4().hex, - "date": "", - "version": "5.3", + 'header': { + 'msg_id': msg_id, + 'msg_type': 'execute_request', + 'username': 'user', + 'session': uuid.uuid4().hex, + 'date': '', + 'version': '5.3', }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": self.code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, + 'parent_header': {}, + 'metadata': {}, + 'content': { + 'code': self.code, + 'silent': False, + 'store_history': True, + 'user_expressions': {}, + 'allow_stdin': False, + 'stop_on_error': True, }, - "channel": "shell", + 'channel': 'shell', } ) ) # parse message - stdout, stderr, result = "", "", [] + stdout, stderr, result = '', '', [] while True: try: # wait for message message = await asyncio.wait_for(ws.recv(), self.timeout) message_data = json.loads(message) # msg id not match, skip - if message_data.get("parent_header", {}).get("msg_id") != msg_id: + if message_data.get('parent_header', {}).get('msg_id') != msg_id: continue # check message type - msg_type = message_data.get("msg_type") + msg_type = message_data.get('msg_type') match msg_type: - case "stream": - if message_data["content"]["name"] == "stdout": - stdout += message_data["content"]["text"] - elif message_data["content"]["name"] == "stderr": - stderr += message_data["content"]["text"] - case "execute_result" | "display_data": - data = message_data["content"]["data"] - if "image/png" in data: - result.append(f"data:image/png;base64,{data['image/png']}") - elif "text/plain" in data: - result.append(data["text/plain"]) - case "error": - stderr += "\n".join(message_data["content"]["traceback"]) - case "status": - if message_data["content"]["execution_state"] == "idle": + case 'stream': + if message_data['content']['name'] == 'stdout': + stdout += message_data['content']['text'] + elif message_data['content']['name'] == 'stderr': + stderr += message_data['content']['text'] + case 'execute_result' | 'display_data': + data = message_data['content']['data'] + if 'image/png' in data: + result.append(f'data:image/png;base64,{data["image/png"]}') + elif 'text/plain' in data: + result.append(data['text/plain']) + case 'error': + stderr += '\n'.join(message_data['content']['traceback']) + case 'status': + if message_data['content']['execution_state'] == 'idle': break except asyncio.TimeoutError: - stderr += "\nExecution timed out." + stderr += '\nExecution timed out.' break self.result.stdout = stdout.strip() self.result.stderr = stderr.strip() - self.result.result = "\n".join(result).strip() if result else "" + self.result.result = '\n'.join(result).strip() if result else '' async def execute_code_jupyter( - base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60 + base_url: str, code: str, token: str = '', password: str = '', timeout: int = 60 ) -> dict: - async with JupyterCodeExecuter( - base_url, code, token, password, timeout - ) as executor: + async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor: result = await executor.run() return result.model_dump() diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py index a2dc080cb5..251b5edf7e 100644 --- a/backend/open_webui/utils/embeddings.py +++ b/backend/open_webui/utils/embeddings.py @@ -43,35 +43,35 @@ async def generate_embeddings( bypass_filter = True # Attach extra metadata from request.state if present - if hasattr(request.state, "metadata"): - if "metadata" not in form_data: - form_data["metadata"] = request.state.metadata + if hasattr(request.state, 'metadata'): + if 'metadata' not in form_data: + form_data['metadata'] = request.state.metadata else: - form_data["metadata"] = { - **form_data["metadata"], + form_data['metadata'] = { + **form_data['metadata'], **request.state.metadata, } # If "direct" flag present, use only that model - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS - model_id = form_data.get("model") + model_id = form_data.get('model') if model_id not in models: - raise Exception("Model not found") + raise Exception('Model not found') model = models[model_id] # Access filtering - if not getattr(request.state, "direct", False): - if not bypass_filter and user.role == "user": + if not getattr(request.state, 'direct', False): + if not bypass_filter and user.role == 'user': check_model_access(user, model) # Ollama backend — use /api/embed which supports batch input natively - if model.get("owned_by") == "ollama": + if model.get('owned_by') == 'ollama': ollama_payload = convert_embed_payload_openai_to_ollama(form_data) response = await ollama_embed( request=request, diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index af8818d59b..3bb918e8da 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -27,22 +27,22 @@ import re import requests -BASE64_IMAGE_URL_PREFIX = re.compile(r"data:image/\w+;base64,", re.IGNORECASE) -MARKDOWN_IMAGE_URL_PATTERN = re.compile(r"!\[(.*?)\]\((.+?)\)", re.IGNORECASE) +BASE64_IMAGE_URL_PREFIX = re.compile(r'data:image/\w+;base64,', re.IGNORECASE) +MARKDOWN_IMAGE_URL_PATTERN = re.compile(r'!\[(.*?)\]\((.+?)\)', re.IGNORECASE) def get_image_base64_from_url(url: str) -> Optional[str]: try: - if url.startswith("http"): + if url.startswith('http'): # Validate URL to prevent SSRF attacks against local/private networks validate_url(url) # Download the image from the URL response = requests.get(url) response.raise_for_status() image_data = response.content - encoded_string = base64.b64encode(image_data).decode("utf-8") - content_type = response.headers.get("Content-Type", "image/png") - return f"data:{content_type};base64,{encoded_string}" + encoded_string = base64.b64encode(image_data).decode('utf-8') + content_type = response.headers.get('Content-Type', 'image/png') + return f'data:{content_type};base64,{encoded_string}' else: file = Files.get_file_by_id(url) @@ -53,10 +53,10 @@ def get_image_base64_from_url(url: str) -> Optional[str]: file_path = Path(file_path) if file_path.is_file(): - with open(file_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + with open(file_path, 'rb') as image_file: + encoded_string = base64.b64encode(image_file.read()).decode('utf-8') content_type, _ = mimetypes.guess_type(file_path.name) - return f"data:{content_type};base64,{encoded_string}" + return f'data:{content_type};base64,{encoded_string}' else: return None @@ -66,7 +66,7 @@ def get_image_base64_from_url(url: str) -> Optional[str]: def get_image_url_from_base64(request, base64_image_string, metadata, user): if BASE64_IMAGE_URL_PREFIX.match(base64_image_string): - image_url = "" + image_url = '' # Extract base64 image data from the line image_data, content_type = get_image_data(base64_image_string) if image_data is not None: @@ -89,7 +89,7 @@ def convert_markdown_base64_images(request, content: str, metadata, user): if len(base64_string) > MIN_REPLACEMENT_URL_LENGTH: url = get_image_url_from_base64(request, base64_string, metadata, user) if url: - return f"![{match.group(1)}]({url})" + return f'![{match.group(1)}]({url})' return match.group(0) return MARKDOWN_IMAGE_URL_PATTERN.sub(replace, content) @@ -97,18 +97,16 @@ def convert_markdown_base64_images(request, content: str, metadata, user): def load_b64_audio_data(b64_str): try: - if "," in b64_str: - header, b64_data = b64_str.split(",", 1) + if ',' in b64_str: + header, b64_data = b64_str.split(',', 1) else: b64_data = b64_str - header = "data:audio/wav;base64" + header = 'data:audio/wav;base64' audio_data = base64.b64decode(b64_data) - content_type = ( - header.split(";")[0].split(":")[1] if ";" in header else "audio/wav" - ) + content_type = header.split(';')[0].split(':')[1] if ';' in header else 'audio/wav' return audio_data, content_type except Exception as e: - print(f"Error decoding base64 audio data: {e}") + print(f'Error decoding base64 audio data: {e}') return None, None @@ -116,9 +114,9 @@ def upload_audio(request, audio_data, content_type, metadata, user): audio_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(audio_data), - filename=f"generated-{audio_format}", # will be converted to a unique ID on upload_file + filename=f'generated-{audio_format}', # will be converted to a unique ID on upload_file headers={ - "content-type": content_type, + 'content-type': content_type, }, ) file_item = upload_file_handler( @@ -128,13 +126,13 @@ def upload_audio(request, audio_data, content_type, metadata, user): process=False, user=user, ) - url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + url = request.app.url_path_for('get_file_content_by_id', id=file_item.id) return url def get_audio_url_from_base64(request, base64_audio_string, metadata, user): - if "data:audio/wav;base64" in base64_audio_string: - audio_url = "" + if 'data:audio/wav;base64' in base64_audio_string: + audio_url = '' # Extract base64 audio data from the line audio_data, content_type = load_b64_audio_data(base64_audio_string) if audio_data is not None: @@ -150,9 +148,9 @@ def get_audio_url_from_base64(request, base64_audio_string, metadata, user): def get_file_url_from_base64(request, base64_file_string, metadata, user): - if "data:image/png;base64" in base64_file_string: + if 'data:image/png;base64' in base64_file_string: return get_image_url_from_base64(request, base64_file_string, metadata, user) - elif "data:audio/wav;base64" in base64_file_string: + elif 'data:audio/wav;base64' in base64_file_string: return get_audio_url_from_base64(request, base64_file_string, metadata, user) return None @@ -170,10 +168,10 @@ def get_image_base64_from_file_id(id: str) -> Optional[str]: if file_path.is_file(): import base64 - with open(file_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + with open(file_path, 'rb') as image_file: + encoded_string = base64.b64encode(image_file.read()).decode('utf-8') content_type, _ = mimetypes.guess_type(file_path.name) - return f"data:{content_type};base64,{encoded_string}" + return f'data:{content_type};base64,{encoded_string}' else: return None except Exception as e: diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 9c71f0d651..7f3f4e8ee2 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -14,9 +14,7 @@ def get_function_module(request, function_id, load_from_db=True): """ Get the function module by its ID. """ - function_module, _, _ = get_function_module_from_cache( - request, function_id, load_from_db - ) + function_module, _, _ = get_function_module_from_cache(request, function_id, load_from_db) return function_module @@ -24,34 +22,29 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None) def get_priority(function_id): try: function_module = get_function_module(request, function_id) - if function_module and hasattr(function_module, "Valves"): + if function_module and hasattr(function_module, 'Valves'): valves_db = Functions.get_function_valves_by_id(function_id) valves = function_module.Valves(**(valves_db if valves_db else {})) - return getattr(valves, "priority", 0) + return getattr(valves, 'priority', 0) except Exception: pass return 0 filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + if 'info' in model and 'meta' in model['info']: + filter_ids.extend(model['info']['meta'].get('filterIds', [])) filter_ids = list(set(filter_ids)) - active_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] + active_filter_ids = [function.id for function in Functions.get_functions_by_type('filter', active_only=True)] def get_active_status(filter_id): function_module = get_function_module(request, filter_id) - if getattr(function_module, "toggle", None): + if getattr(function_module, 'toggle', None): return filter_id in (enabled_filter_ids or []) return True - active_filter_ids = [ - filter_id for filter_id in active_filter_ids if get_active_status(filter_id) - ] + active_filter_ids = [filter_id for filter_id in active_filter_ids if get_active_status(filter_id)] filter_ids = [fid for fid in filter_ids if fid in active_filter_ids] filter_ids.sort(key=lambda fid: (get_priority(fid), fid)) @@ -59,9 +52,7 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None) return filter_ids -async def process_filter_functions( - request, filter_functions, filter_type, form_data, extra_params -): +async def process_filter_functions(request, filter_functions, filter_type, form_data, extra_params): skip_files = None for function in filter_functions: @@ -70,53 +61,47 @@ async def process_filter_functions( if not filter: continue - function_module = get_function_module( - request, filter_id, load_from_db=(filter_type != "stream") - ) + function_module = get_function_module(request, filter_id, load_from_db=(filter_type != 'stream')) # Prepare handler function handler = getattr(function_module, filter_type, None) if not handler: continue # Check if the function has a file_handler variable - if filter_type == "inlet" and hasattr(function_module, "file_handler"): + if filter_type == 'inlet' and hasattr(function_module, 'file_handler'): skip_files = function_module.file_handler # Apply valves to the function - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'): valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + function_module.valves = function_module.Valves(**(valves if valves else {})) try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": form_data} - if filter_type == "stream": - params = {"event": form_data} + params = {'body': form_data} + if filter_type == 'stream': + params = {'event': form_data} params = params | { k: v for k, v in { **extra_params, - "__id__": filter_id, + '__id__': filter_id, }.items() if k in sig.parameters } # Handle user parameters - if "__user__" in sig.parameters: - if hasattr(function_module, "UserValves"): + if '__user__' in sig.parameters: + if hasattr(function_module, 'UserValves'): try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) + params['__user__']['valves'] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, params['__user__']['id']) ) except Exception as e: - log.exception(f"Failed to get user values: {e}") + log.exception(f'Failed to get user values: {e}') # Execute handler if inspect.iscoroutinefunction(handler): @@ -125,14 +110,14 @@ async def process_filter_functions( form_data = handler(**params) except Exception as e: - log.debug(f"Error in {filter_type} handler {filter_id}: {e}") + log.debug(f'Error in {filter_type} handler {filter_id}: {e}') raise e # Handle file cleanup for inlet if skip_files: - if "files" in form_data.get("metadata", {}): - del form_data["metadata"]["files"] - if "files" in form_data: - del form_data["files"] + if 'files' in form_data.get('metadata', {}): + del form_data['metadata']['files'] + if 'files' in form_data: + del form_data['files'] return form_data, {} diff --git a/backend/open_webui/utils/groups.py b/backend/open_webui/utils/groups.py index 26fc5d8434..90c4593cec 100644 --- a/backend/open_webui/utils/groups.py +++ b/backend/open_webui/utils/groups.py @@ -20,6 +20,4 @@ def apply_default_group_assignment( try: Groups.add_users_to_group(default_group_id, [user_id], db=db) except Exception as e: - log.error( - f"Failed to add user {user_id} to default group {default_group_id}: {e}" - ) + log.error(f'Failed to add user {user_id} to default group {default_group_id}: {e}') diff --git a/backend/open_webui/utils/headers.py b/backend/open_webui/utils/headers.py index f0b13c00d3..0baee5edb9 100644 --- a/backend/open_webui/utils/headers.py +++ b/backend/open_webui/utils/headers.py @@ -11,7 +11,7 @@ from open_webui.env import ( def include_user_info_headers(headers, user): return { **headers, - FORWARD_USER_INFO_HEADER_USER_NAME: quote(user.name, safe=" "), + FORWARD_USER_INFO_HEADER_USER_NAME: quote(user.name, safe=' '), FORWARD_USER_INFO_HEADER_USER_ID: user.id, FORWARD_USER_INFO_HEADER_USER_EMAIL: user.email, FORWARD_USER_INFO_HEADER_USER_ROLE: user.role, diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py index 3c402cbc17..497808c22d 100644 --- a/backend/open_webui/utils/images/comfyui.py +++ b/backend/open_webui/utils/images/comfyui.py @@ -13,99 +13,97 @@ from pydantic import BaseModel log = logging.getLogger(__name__) -default_headers = {"User-Agent": "Mozilla/5.0"} +default_headers = {'User-Agent': 'Mozilla/5.0'} def queue_prompt(prompt, client_id, base_url, api_key): - log.info("queue_prompt") - p = {"prompt": prompt, "client_id": client_id} - data = json.dumps(p).encode("utf-8") - log.debug(f"queue_prompt data: {data}") + log.info('queue_prompt') + p = {'prompt': prompt, 'client_id': client_id} + data = json.dumps(p).encode('utf-8') + log.debug(f'queue_prompt data: {data}') try: req = urllib.request.Request( - f"{base_url}/prompt", + f'{base_url}/prompt', data=data, - headers={**default_headers, "Authorization": f"Bearer {api_key}"}, + headers={**default_headers, 'Authorization': f'Bearer {api_key}'}, ) response = urllib.request.urlopen(req).read() return json.loads(response) except Exception as e: - log.exception(f"Error while queuing prompt: {e}") + log.exception(f'Error while queuing prompt: {e}') raise e def get_image(filename, subfolder, folder_type, base_url, api_key): - log.info("get_image") - data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + log.info('get_image') + data = {'filename': filename, 'subfolder': subfolder, 'type': folder_type} url_values = urllib.parse.urlencode(data) req = urllib.request.Request( - f"{base_url}/view?{url_values}", - headers={**default_headers, "Authorization": f"Bearer {api_key}"}, + f'{base_url}/view?{url_values}', + headers={**default_headers, 'Authorization': f'Bearer {api_key}'}, ) with urllib.request.urlopen(req) as response: return response.read() def get_image_url(filename, subfolder, folder_type, base_url): - log.info("get_image") - data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + log.info('get_image') + data = {'filename': filename, 'subfolder': subfolder, 'type': folder_type} url_values = urllib.parse.urlencode(data) - return f"{base_url}/view?{url_values}" + return f'{base_url}/view?{url_values}' def get_history(prompt_id, base_url, api_key): - log.info("get_history") + log.info('get_history') req = urllib.request.Request( - f"{base_url}/history/{prompt_id}", - headers={**default_headers, "Authorization": f"Bearer {api_key}"}, + f'{base_url}/history/{prompt_id}', + headers={**default_headers, 'Authorization': f'Bearer {api_key}'}, ) with urllib.request.urlopen(req) as response: return json.loads(response.read()) def get_images(ws, workflow, client_id, base_url, api_key): - prompt_id = queue_prompt(workflow, client_id, base_url, api_key)["prompt_id"] + prompt_id = queue_prompt(workflow, client_id, base_url, api_key)['prompt_id'] output_images = [] while True: out = ws.recv() if isinstance(out, str): message = json.loads(out) - if message["type"] == "executing": - data = message["data"] - if data["node"] is None and data["prompt_id"] == prompt_id: + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: break # Execution is done else: continue # previews are binary data history = get_history(prompt_id, base_url, api_key)[prompt_id] - for node_id in history["outputs"]: - node_output = history["outputs"][node_id] - if node_id in workflow and workflow[node_id].get("class_type") in [ - "SaveImage", - "PreviewImage", + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if node_id in workflow and workflow[node_id].get('class_type') in [ + 'SaveImage', + 'PreviewImage', ]: - if "images" in node_output: - for image in node_output["images"]: - url = get_image_url( - image["filename"], image["subfolder"], image["type"], base_url - ) - output_images.append({"url": url}) - return {"data": output_images} + if 'images' in node_output: + for image in node_output['images']: + url = get_image_url(image['filename'], image['subfolder'], image['type'], base_url) + output_images.append({'url': url}) + return {'data': output_images} async def comfyui_upload_image(image_file_item, base_url, api_key): - url = f"{base_url}/api/upload/image" + url = f'{base_url}/api/upload/image' headers = {} if api_key: - headers["Authorization"] = f"Bearer {api_key}" + headers['Authorization'] = f'Bearer {api_key}' _, (filename, file_bytes, mime_type) = image_file_item form = aiohttp.FormData() - form.add_field("image", file_bytes, filename=filename, content_type=mime_type) - form.add_field("type", "input") # required by ComfyUI + form.add_field('image', file_bytes, filename=filename, content_type=mime_type) + form.add_field('type', 'input') # required by ComfyUI async with aiohttp.ClientSession() as session: async with session.post(url, data=form, headers=headers) as resp: @@ -116,7 +114,7 @@ async def comfyui_upload_image(image_file_item, base_url, api_key): class ComfyUINodeInput(BaseModel): type: Optional[str] = None node_ids: list[str] = [] - key: Optional[str] = "text" + key: Optional[str] = 'text' value: Optional[str] = None @@ -138,76 +136,56 @@ class ComfyUICreateImageForm(BaseModel): seed: Optional[int] = None -async def comfyui_create_image( - model: str, payload: ComfyUICreateImageForm, client_id, base_url, api_key -): - ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") +async def comfyui_create_image(model: str, payload: ComfyUICreateImageForm, client_id, base_url, api_key): + ws_url = base_url.replace('http://', 'ws://').replace('https://', 'wss://') workflow = json.loads(payload.workflow.workflow) for node in payload.workflow.nodes: if node.type: - if node.type == "model": + if node.type == 'model': for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = model - elif node.type == "prompt": + workflow[node_id]['inputs'][node.key] = model + elif node.type == 'prompt': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "text" - ] = payload.prompt - elif node.type == "negative_prompt": + workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.prompt + elif node.type == 'negative_prompt': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "text" - ] = payload.negative_prompt - elif node.type == "width": + workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.negative_prompt + elif node.type == 'width': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "width" - ] = payload.width - elif node.type == "height": + workflow[node_id]['inputs'][node.key if node.key else 'width'] = payload.width + elif node.type == 'height': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "height" - ] = payload.height - elif node.type == "n": + workflow[node_id]['inputs'][node.key if node.key else 'height'] = payload.height + elif node.type == 'n': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "batch_size" - ] = payload.n - elif node.type == "steps": + workflow[node_id]['inputs'][node.key if node.key else 'batch_size'] = payload.n + elif node.type == 'steps': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "steps" - ] = payload.steps - elif node.type == "seed": - seed = ( - payload.seed - if payload.seed - else random.randint(0, 1125899906842624) - ) + workflow[node_id]['inputs'][node.key if node.key else 'steps'] = payload.steps + elif node.type == 'seed': + seed = payload.seed if payload.seed else random.randint(0, 1125899906842624) for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = seed + workflow[node_id]['inputs'][node.key] = seed else: for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = node.value + workflow[node_id]['inputs'][node.key] = node.value try: ws = websocket.WebSocket() - headers = {"Authorization": f"Bearer {api_key}"} - ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers) - log.info("WebSocket connection established.") + headers = {'Authorization': f'Bearer {api_key}'} + ws.connect(f'{ws_url}/ws?clientId={client_id}', header=headers) + log.info('WebSocket connection established.') except Exception as e: - log.exception(f"Failed to connect to WebSocket server: {e}") + log.exception(f'Failed to connect to WebSocket server: {e}') return None try: - log.info("Sending workflow to WebSocket server.") - log.info(f"Workflow: {workflow}") - images = await asyncio.to_thread( - get_images, ws, workflow, client_id, base_url, api_key - ) + log.info('Sending workflow to WebSocket server.') + log.info(f'Workflow: {workflow}') + images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url, api_key) except Exception as e: - log.exception(f"Error while receiving images: {e}") + log.exception(f'Error while receiving images: {e}') images = None ws.close() @@ -228,85 +206,65 @@ class ComfyUIEditImageForm(BaseModel): seed: Optional[int] = None -async def comfyui_edit_image( - model: str, payload: ComfyUIEditImageForm, client_id, base_url, api_key -): - ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") +async def comfyui_edit_image(model: str, payload: ComfyUIEditImageForm, client_id, base_url, api_key): + ws_url = base_url.replace('http://', 'ws://').replace('https://', 'wss://') workflow = json.loads(payload.workflow.workflow) for node in payload.workflow.nodes: if node.type: - if node.type == "model": + if node.type == 'model': for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = model - elif node.type == "image": + workflow[node_id]['inputs'][node.key] = model + elif node.type == 'image': if isinstance(payload.image, list): # check if multiple images are provided for idx, node_id in enumerate(node.node_ids): if idx < len(payload.image): - workflow[node_id]["inputs"][node.key] = payload.image[idx] + workflow[node_id]['inputs'][node.key] = payload.image[idx] else: for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = payload.image - elif node.type == "prompt": + workflow[node_id]['inputs'][node.key] = payload.image + elif node.type == 'prompt': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "text" - ] = payload.prompt - elif node.type == "negative_prompt": + workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.prompt + elif node.type == 'negative_prompt': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "text" - ] = payload.negative_prompt - elif node.type == "width": + workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.negative_prompt + elif node.type == 'width': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "width" - ] = payload.width - elif node.type == "height": + workflow[node_id]['inputs'][node.key if node.key else 'width'] = payload.width + elif node.type == 'height': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "height" - ] = payload.height - elif node.type == "n": + workflow[node_id]['inputs'][node.key if node.key else 'height'] = payload.height + elif node.type == 'n': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "batch_size" - ] = payload.n - elif node.type == "steps": + workflow[node_id]['inputs'][node.key if node.key else 'batch_size'] = payload.n + elif node.type == 'steps': for node_id in node.node_ids: - workflow[node_id]["inputs"][ - node.key if node.key else "steps" - ] = payload.steps - elif node.type == "seed": - seed = ( - payload.seed - if payload.seed - else random.randint(0, 1125899906842624) - ) + workflow[node_id]['inputs'][node.key if node.key else 'steps'] = payload.steps + elif node.type == 'seed': + seed = payload.seed if payload.seed else random.randint(0, 1125899906842624) for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = seed + workflow[node_id]['inputs'][node.key] = seed else: for node_id in node.node_ids: - workflow[node_id]["inputs"][node.key] = node.value + workflow[node_id]['inputs'][node.key] = node.value try: ws = websocket.WebSocket() - headers = {"Authorization": f"Bearer {api_key}"} - ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers) - log.info("WebSocket connection established.") + headers = {'Authorization': f'Bearer {api_key}'} + ws.connect(f'{ws_url}/ws?clientId={client_id}', header=headers) + log.info('WebSocket connection established.') except Exception as e: - log.exception(f"Failed to connect to WebSocket server: {e}") + log.exception(f'Failed to connect to WebSocket server: {e}') return None try: - log.info("Sending workflow to WebSocket server.") - log.info(f"Workflow: {workflow}") - images = await asyncio.to_thread( - get_images, ws, workflow, client_id, base_url, api_key - ) + log.info('Sending workflow to WebSocket server.') + log.info(f'Workflow: {workflow}') + images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url, api_key) except Exception as e: - log.exception(f"Error while receiving images: {e}") + log.exception(f'Error while receiving images: {e}') images = None ws.close() diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py index 26a525fc0b..5cc34fe923 100644 --- a/backend/open_webui/utils/logger.py +++ b/backend/open_webui/utils/logger.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from loguru import Message, Record -def stdout_format(record: "Record") -> str: +def stdout_format(record: 'Record') -> str: """ Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON). @@ -32,39 +32,39 @@ def stdout_format(record: "Record") -> str: Returns: str: A formatted log string intended for stdout. """ - if record["extra"]: - record["extra"]["extra_json"] = json.dumps(record["extra"]) - extra_format = " - {extra[extra_json]}" + if record['extra']: + record['extra']['extra_json'] = json.dumps(record['extra']) + extra_format = ' - {extra[extra_json]}' else: - extra_format = "" + extra_format = '' return ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} - " - "{message}" + extra_format + "\n{exception}" + '{time:YYYY-MM-DD HH:mm:ss.SSS} | ' + '{level: <8} | ' + '{name}:{function}:{line} - ' + '{message}' + extra_format + '\n{exception}' ) -def _json_sink(message: "Message") -> None: +def _json_sink(message: 'Message') -> None: """Write log records as single-line JSON to stdout. Used as a Loguru sink when LOG_FORMAT is set to "json". """ record = message.record log_entry = { - "ts": record["time"].strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", - "level": _LEVEL_MAP.get(record["level"].name, record["level"].name.lower()), - "msg": record["message"], - "caller": f"{record['name']}:{record['function']}:{record['line']}", + 'ts': record['time'].strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z', + 'level': _LEVEL_MAP.get(record['level'].name, record['level'].name.lower()), + 'msg': record['message'], + 'caller': f'{record["name"]}:{record["function"]}:{record["line"]}', } - if record["extra"]: - log_entry["extra"] = record["extra"] + if record['extra']: + log_entry['extra'] = record['extra'] - if record["exception"] is not None: - log_entry["error"] = "".join(record["exception"].format_exception()).rstrip() + if record['exception'] is not None: + log_entry['error'] = ''.join(record['exception'].format_exception()).rstrip() - sys.stdout.write(json.dumps(log_entry, ensure_ascii=False, default=str) + "\n") + sys.stdout.write(json.dumps(log_entry, ensure_ascii=False, default=str) + '\n') sys.stdout.flush() @@ -90,9 +90,7 @@ class InterceptHandler(logging.Handler): frame = frame.f_back depth += 1 - logger.opt(depth=depth, exception=record.exc_info).bind( - **self._get_extras() - ).log(level, record.getMessage()) + logger.opt(depth=depth, exception=record.exc_info).bind(**self._get_extras()).log(level, record.getMessage()) if ENABLE_OTEL and ENABLE_OTEL_LOGS: from open_webui.utils.telemetry.logs import otel_handler @@ -105,12 +103,12 @@ class InterceptHandler(logging.Handler): extras = {} context = trace.get_current_span().get_span_context() if context.is_valid: - extras["trace_id"] = trace.format_trace_id(context.trace_id) - extras["span_id"] = trace.format_span_id(context.span_id) + extras['trace_id'] = trace.format_trace_id(context.trace_id) + extras['span_id'] = trace.format_span_id(context.span_id) return extras -def file_format(record: "Record"): +def file_format(record: 'Record'): """ Formats audit log records into a structured JSON string for file output. @@ -121,22 +119,22 @@ def file_format(record: "Record"): """ audit_data = { - "id": record["extra"].get("id", ""), - "timestamp": int(record["time"].timestamp()), - "user": record["extra"].get("user", dict()), - "audit_level": record["extra"].get("audit_level", ""), - "verb": record["extra"].get("verb", ""), - "request_uri": record["extra"].get("request_uri", ""), - "response_status_code": record["extra"].get("response_status_code", 0), - "source_ip": record["extra"].get("source_ip", ""), - "user_agent": record["extra"].get("user_agent", ""), - "request_object": record["extra"].get("request_object", b""), - "response_object": record["extra"].get("response_object", b""), - "extra": record["extra"].get("extra", {}), + 'id': record['extra'].get('id', ''), + 'timestamp': int(record['time'].timestamp()), + 'user': record['extra'].get('user', dict()), + 'audit_level': record['extra'].get('audit_level', ''), + 'verb': record['extra'].get('verb', ''), + 'request_uri': record['extra'].get('request_uri', ''), + 'response_status_code': record['extra'].get('response_status_code', 0), + 'source_ip': record['extra'].get('source_ip', ''), + 'user_agent': record['extra'].get('user_agent', ''), + 'request_object': record['extra'].get('request_object', b''), + 'response_object': record['extra'].get('response_object', b''), + 'extra': record['extra'].get('extra', {}), } - record["extra"]["file_extra"] = json.dumps(audit_data, default=str) - return "{extra[file_extra]}\n" + record['extra']['file_extra'] = json.dumps(audit_data, default=str) + return '{extra[file_extra]}\n' def start_logger(): @@ -152,10 +150,8 @@ def start_logger(): """ logger.remove() - audit_filter = lambda record: ( - True if ENABLE_AUDIT_STDOUT else "auditable" not in record["extra"] - ) - if LOG_FORMAT == "json": + audit_filter = lambda record: (True if ENABLE_AUDIT_STDOUT else 'auditable' not in record['extra']) + if LOG_FORMAT == 'json': logger.add( _json_sink, level=GLOBAL_LOG_LEVEL, @@ -168,24 +164,22 @@ def start_logger(): format=stdout_format, filter=audit_filter, ) - if AUDIT_LOG_LEVEL != "NONE" and ENABLE_AUDIT_LOGS_FILE: + if AUDIT_LOG_LEVEL != 'NONE' and ENABLE_AUDIT_LOGS_FILE: try: logger.add( AUDIT_LOGS_FILE_PATH, - level="INFO", + level='INFO', rotation=AUDIT_LOG_FILE_ROTATION_SIZE, - compression="zip", + compression='zip', format=file_format, - filter=lambda record: record["extra"].get("auditable") is True, + filter=lambda record: record['extra'].get('auditable') is True, ) except Exception as e: - logger.error(f"Failed to initialize audit log file handler: {str(e)}") + logger.error(f'Failed to initialize audit log file handler: {str(e)}') - logging.basicConfig( - handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True - ) + logging.basicConfig(handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True) - for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]: + for uvicorn_logger_name in ['uvicorn', 'uvicorn.error']: uvicorn_logger = logging.getLogger(uvicorn_logger_name) uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) uvicorn_logger.handlers = [] @@ -195,4 +189,4 @@ def start_logger(): uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) uvicorn_logger.handlers = [InterceptHandler()] - logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") + logger.info(f'GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}') diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 1004536e4d..fbabb390aa 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -20,15 +20,15 @@ def create_insecure_httpx_client(headers=None, timeout=None, auth=None): after construction does not affect the underlying transport's SSL context. """ kwargs = { - "follow_redirects": True, - "verify": False, + 'follow_redirects': True, + 'verify': False, } if timeout is not None: - kwargs["timeout"] = timeout + kwargs['timeout'] = timeout if headers is not None: - kwargs["headers"] = headers + kwargs['headers'] = headers if auth is not None: - kwargs["auth"] = auth + kwargs['auth'] = auth return httpx.AsyncClient(**kwargs) @@ -52,13 +52,9 @@ class MCPClient: transport = await exit_stack.enter_async_context(self._streams_context) read_stream, write_stream, _ = transport - self._session_context = ClientSession( - read_stream, write_stream - ) # pylint: disable=W0201 + self._session_context = ClientSession(read_stream, write_stream) # pylint: disable=W0201 - self.session = await exit_stack.enter_async_context( - self._session_context - ) + self.session = await exit_stack.enter_async_context(self._session_context) with anyio.fail_after(10): await self.session.initialize() self.exit_stack = exit_stack.pop_all() @@ -68,7 +64,7 @@ class MCPClient: async def list_tool_specs(self) -> Optional[dict]: if not self.session: - raise RuntimeError("MCP client is not connected.") + raise RuntimeError('MCP client is not connected.') result = await self.session.list_tools() tools = result.tools @@ -81,26 +77,22 @@ class MCPClient: inputSchema = tool.inputSchema # TODO: handle outputSchema if needed - outputSchema = getattr(tool, "outputSchema", None) + outputSchema = getattr(tool, 'outputSchema', None) - tool_specs.append( - {"name": name, "description": description, "parameters": inputSchema} - ) + tool_specs.append({'name': name, 'description': description, 'parameters': inputSchema}) return tool_specs - async def call_tool( - self, function_name: str, function_args: dict - ) -> Optional[dict]: + async def call_tool(self, function_name: str, function_args: dict) -> Optional[dict]: if not self.session: - raise RuntimeError("MCP client is not connected.") + raise RuntimeError('MCP client is not connected.') result = await self.session.call_tool(function_name, function_args) if not result: - raise Exception("No result returned from MCP tool call.") + raise Exception('No result returned from MCP tool call.') - result_dict = result.model_dump(mode="json") - result_content = result_dict.get("content", {}) + result_dict = result.model_dump(mode='json') + result_content = result_dict.get('content', {}) if result.isError: raise Exception(result_content) @@ -109,24 +101,24 @@ class MCPClient: async def list_resources(self, cursor: Optional[str] = None) -> Optional[dict]: if not self.session: - raise RuntimeError("MCP client is not connected.") + raise RuntimeError('MCP client is not connected.') result = await self.session.list_resources(cursor=cursor) if not result: - raise Exception("No result returned from MCP list_resources call.") + raise Exception('No result returned from MCP list_resources call.') result_dict = result.model_dump() - resources = result_dict.get("resources", []) + resources = result_dict.get('resources', []) return resources async def read_resource(self, uri: str) -> Optional[dict]: if not self.session: - raise RuntimeError("MCP client is not connected.") + raise RuntimeError('MCP client is not connected.') result = await self.session.read_resource(uri) if not result: - raise Exception("No result returned from MCP read_resource call.") + raise Exception('No result returned from MCP read_resource call.') result_dict = result.model_dump() return result_dict diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index f0a79ef4b1..ae1b557da7 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -144,22 +144,22 @@ log = logging.getLogger(__name__) DEFAULT_REASONING_TAGS = [ - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("<|begin_of_thought|>", "<|end_of_thought|>"), - ("◁think▷", "◁/think▷"), + ('', ''), + ('', ''), + ('', ''), + ('', ''), + ('', ''), + ('', ''), + ('<|begin_of_thought|>', '<|end_of_thought|>'), + ('◁think▷', '◁/think▷'), ] -DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")] -DEFAULT_CODE_INTERPRETER_TAGS = [("", "")] +DEFAULT_SOLUTION_TAGS = [('<|begin_of_solution|>', '<|end_of_solution|>')] +DEFAULT_CODE_INTERPRETER_TAGS = [('', '')] def output_id(prefix: str) -> str: """Generate OR-style ID: prefix + 24-char hex UUID.""" - return f"{prefix}_{uuid4().hex[:24]}" + return f'{prefix}_{uuid4().hex[:24]}' def _split_tool_calls( @@ -196,7 +196,7 @@ def _split_tool_calls( expanded = [] for tool_call in tool_calls: - arguments = tool_call.get("function", {}).get("arguments", "") + arguments = tool_call.get('function', {}).get('arguments', '') split_arguments = split_json_objects(arguments) if len(split_arguments) <= 1: @@ -204,15 +204,15 @@ def _split_tool_calls( else: for argument in split_arguments: cloned = copy.deepcopy(tool_call) - cloned["id"] = f"call_{uuid4().hex[:24]}" - cloned["function"]["arguments"] = argument + cloned['id'] = f'call_{uuid4().hex[:24]}' + cloned['function']['arguments'] = argument expanded.append(cloned) return expanded def get_citation_source_from_tool_result( - tool_name: str, tool_params: dict, tool_result: str, tool_id: str = "" + tool_name: str, tool_params: dict, tool_result: str, tool_id: str = '' ) -> list[dict]: """ Parse a tool's result and convert it to source dicts for citation display. @@ -224,15 +224,15 @@ def get_citation_source_from_tool_result( Returns a list of sources (usually one, but query_knowledge_files may return multiple). """ - _EXPECTS_LIST = {"search_web", "query_knowledge_files"} - _EXPECTS_DICT = {"view_knowledge_file"} + _EXPECTS_LIST = {'search_web', 'query_knowledge_files'} + _EXPECTS_DICT = {'view_knowledge_file'} try: try: tool_result = json.loads(tool_result) except (json.JSONDecodeError, TypeError): pass # keep tool_result as-is (e.g. fetch_url returns plain text) - if isinstance(tool_result, dict) and "error" in tool_result: + if isinstance(tool_result, dict) and 'error' in tool_result: return [] # Validate tool_result type based on what the branch expects @@ -241,83 +241,79 @@ def get_citation_source_from_tool_result( elif tool_name in _EXPECTS_DICT and not isinstance(tool_result, dict): return [] - if tool_name == "search_web": + if tool_name == 'search_web': # Parse JSON array: [{"title": "...", "link": "...", "snippet": "..."}] results = tool_result documents = [] metadata = [] for result in results: - title = result.get("title", "") - link = result.get("link", "") - snippet = result.get("snippet", "") + title = result.get('title', '') + link = result.get('link', '') + snippet = result.get('snippet', '') - documents.append(f"{title}\n{snippet}") + documents.append(f'{title}\n{snippet}') metadata.append( { - "source": link, - "name": title, - "url": link, + 'source': link, + 'name': title, + 'url': link, } ) return [ { - "source": {"name": "search_web", "id": "search_web"}, - "document": documents, - "metadata": metadata, + 'source': {'name': 'search_web', 'id': 'search_web'}, + 'document': documents, + 'metadata': metadata, } ] - elif tool_name == "view_knowledge_file": + elif tool_name == 'view_knowledge_file': file_data = tool_result - filename = file_data.get("filename", "Unknown File") - file_id = file_data.get("id", "") - knowledge_name = file_data.get("knowledge_name", "") + filename = file_data.get('filename', 'Unknown File') + file_id = file_data.get('id', '') + knowledge_name = file_data.get('knowledge_name', '') return [ { - "source": { - "id": file_id, - "name": filename, - "type": "file", + 'source': { + 'id': file_id, + 'name': filename, + 'type': 'file', }, - "document": [file_data.get("content", "")], - "metadata": [ + 'document': [file_data.get('content', '')], + 'metadata': [ { - "file_id": file_id, - "name": filename, - "source": filename, - **( - {"knowledge_name": knowledge_name} - if knowledge_name - else {} - ), + 'file_id': file_id, + 'name': filename, + 'source': filename, + **({'knowledge_name': knowledge_name} if knowledge_name else {}), } ], } ] - elif tool_name == "fetch_url": - url = tool_params.get("url", "") + elif tool_name == 'fetch_url': + url = tool_params.get('url', '') content = tool_result if isinstance(tool_result, str) else str(tool_result) - snippet = content[:500] + ("..." if len(content) > 500 else "") + snippet = content[:500] + ('...' if len(content) > 500 else '') return [ { - "source": {"name": url or "fetch_url", "id": url or "fetch_url"}, - "document": [snippet], - "metadata": [ + 'source': {'name': url or 'fetch_url', 'id': url or 'fetch_url'}, + 'document': [snippet], + 'metadata': [ { - "source": url, - "name": url, - "url": url, + 'source': url, + 'name': url, + 'url': url, } ], } ] - elif tool_name == "query_knowledge_files": + elif tool_name == 'query_knowledge_files': chunks = tool_result # Group chunks by source for better citation display @@ -325,33 +321,33 @@ def get_citation_source_from_tool_result( sources_by_file = {} for chunk in chunks: - source_name = chunk.get("source", "Unknown") - file_id = chunk.get("file_id", "") - note_id = chunk.get("note_id", "") - chunk_type = chunk.get("type", "file") - content = chunk.get("content", "") + source_name = chunk.get('source', 'Unknown') + file_id = chunk.get('file_id', '') + note_id = chunk.get('note_id', '') + chunk_type = chunk.get('type', 'file') + content = chunk.get('content', '') # Use file_id or note_id as the key key = file_id or note_id or source_name if key not in sources_by_file: sources_by_file[key] = { - "source": { - "id": file_id or note_id, - "name": source_name, - "type": chunk_type, + 'source': { + 'id': file_id or note_id, + 'name': source_name, + 'type': chunk_type, }, - "document": [], - "metadata": [], + 'document': [], + 'metadata': [], } - sources_by_file[key]["document"].append(content) - sources_by_file[key]["metadata"].append( + sources_by_file[key]['document'].append(content) + sources_by_file[key]['metadata'].append( { - "file_id": file_id, - "name": source_name, - "source": source_name, - **({"note_id": note_id} if note_id else {}), + 'file_id': file_id, + 'name': source_name, + 'source': source_name, + **({'note_id': note_id} if note_id else {}), } ) @@ -366,36 +362,34 @@ def get_citation_source_from_tool_result( # Fallback for other tools return [ { - "source": { - "name": tool_name, - "type": "tool", - "id": tool_id or tool_name, + 'source': { + 'name': tool_name, + 'type': 'tool', + 'id': tool_id or tool_name, }, - "document": [str(tool_result)], - "metadata": [{"source": tool_name, "name": tool_name}], + 'document': [str(tool_result)], + 'metadata': [{'source': tool_name, 'name': tool_name}], } ] except Exception as e: - log.exception(f"Error parsing tool result for {tool_name}: {e}") + log.exception(f'Error parsing tool result for {tool_name}: {e}') return [ { - "source": {"name": tool_name, "type": "tool"}, - "document": [str(tool_result)], - "metadata": [{"source": tool_name}], + 'source': {'name': tool_name, 'type': 'tool'}, + 'document': [str(tool_result)], + 'metadata': [{'source': tool_name}], } ] def split_content_and_whitespace(content): content_stripped = content.rstrip() - original_whitespace = ( - content[len(content_stripped) :] if len(content) > len(content_stripped) else "" - ) + original_whitespace = content[len(content_stripped) :] if len(content) > len(content_stripped) else '' return content_stripped, original_whitespace def is_opening_code_block(content): - backtick_segments = content.split("```") + backtick_segments = content.split('```') # Even number of segments means the last backticks are opening a new block return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 @@ -405,128 +399,119 @@ def serialize_output(output: list) -> str: Convert OR-aligned output items to HTML for display. For LLM consumption, use convert_output_to_messages() instead. """ - content = "" + content = '' # First pass: collect function_call_output items by call_id for lookup tool_outputs = {} for item in output: - if item.get("type") == "function_call_output": - tool_outputs[item.get("call_id")] = item + if item.get('type') == 'function_call_output': + tool_outputs[item.get('call_id')] = item # Second pass: render items in order for idx, item in enumerate(output): - item_type = item.get("type", "") + item_type = item.get('type', '') - if item_type == "message": - for content_part in item.get("content", []): - if "text" in content_part: - text = content_part.get("text", "").strip() + if item_type == 'message': + for content_part in item.get('content', []): + if 'text' in content_part: + text = content_part.get('text', '').strip() if text: - content = f"{content}{text}\n" + content = f'{content}{text}\n' - elif item_type == "function_call": + elif item_type == 'function_call': # Render tool call inline with its result (if available) - if content and not content.endswith("\n"): - content += "\n" + if content and not content.endswith('\n'): + content += '\n' - call_id = item.get("call_id", "") - name = item.get("name", "") - arguments = item.get("arguments", "") + call_id = item.get('call_id', '') + name = item.get('name', '') + arguments = item.get('arguments', '') result_item = tool_outputs.get(call_id) if result_item: - result_text = "" - for result_output in result_item.get("output", []): - if "text" in result_output: - output_text = result_output.get("text", "") - result_text += ( - str(output_text) - if not isinstance(output_text, str) - else output_text - ) - files = result_item.get("files") - embeds = result_item.get("embeds", "") + result_text = '' + for result_output in result_item.get('output', []): + if 'text' in result_output: + output_text = result_output.get('text', '') + result_text += str(output_text) if not isinstance(output_text, str) else output_text + files = result_item.get('files') + embeds = result_item.get('embeds', '') content += f'
\nTool Executed\n
\n' else: content += f'
\nExecuting...\n
\n' - elif item_type == "function_call_output": + elif item_type == 'function_call_output': # Already handled inline with function_call above pass - elif item_type == "reasoning": - reasoning_content = "" + elif item_type == 'reasoning': + reasoning_content = '' # Check for 'summary' (new structure) or 'content' (legacy/fallback) - source_list = item.get("summary", []) or item.get("content", []) + source_list = item.get('summary', []) or item.get('content', []) for content_part in source_list: - if "text" in content_part: - reasoning_content += content_part.get("text", "") - elif "summary" in content_part: # Handle potential nested logic if any + if 'text' in content_part: + reasoning_content += content_part.get('text', '') + elif 'summary' in content_part: # Handle potential nested logic if any pass reasoning_content = reasoning_content.strip() - duration = item.get("duration") - status = item.get("status", "in_progress") + duration = item.get('duration') + status = item.get('status', 'in_progress') # Infer completion: if this reasoning item is NOT the last item, # render as done (a subsequent item means reasoning is complete) is_last_item = idx == len(output) - 1 - if content and not content.endswith("\n"): - content += "\n" + if content and not content.endswith('\n'): + content += '\n' display = html.escape( - "\n".join( - (f"> {line}" if not line.startswith(">") else line) - for line in reasoning_content.splitlines() + '\n'.join( + (f'> {line}' if not line.startswith('>') else line) for line in reasoning_content.splitlines() ) ) - if status == "completed" or duration is not None or not is_last_item: + if status == 'completed' or duration is not None or not is_last_item: content = f'{content}
\nThought for {duration or 0} seconds\n{display}\n
\n' else: content = f'{content}
\nThinking…\n{display}\n
\n' - elif item_type == "open_webui:code_interpreter": - content_stripped, original_whitespace = split_content_and_whitespace( - content - ) + elif item_type == 'open_webui:code_interpreter': + content_stripped, original_whitespace = split_content_and_whitespace(content) if is_opening_code_block(content_stripped): - content = content_stripped.rstrip("`").rstrip() + original_whitespace + content = content_stripped.rstrip('`').rstrip() + original_whitespace else: content = content_stripped + original_whitespace - if content and not content.endswith("\n"): - content += "\n" + if content and not content.endswith('\n'): + content += '\n' # Render the code_interpreter item as a
block # so the frontend Collapsible renders "Analyzing..."/"Analyzed". - code = item.get("code", "").strip() - lang = item.get("lang", "python") - status = item.get("status", "in_progress") - duration = item.get("duration") + code = item.get('code', '').strip() + lang = item.get('lang', 'python') + status = item.get('status', 'in_progress') + duration = item.get('duration') is_last_item = idx == len(output) - 1 # Build inner content: code block - display = "" + display = '' if code: - display = f"```{lang}\n{code}\n```" + display = f'```{lang}\n{code}\n```' # Build output attribute as HTML-escaped JSON for CodeBlock.svelte - ci_output = item.get("output") - output_attr = "" + ci_output = item.get('output') + output_attr = '' if ci_output: if isinstance(ci_output, dict): output_json = json.dumps(ci_output, ensure_ascii=False) else: - output_json = json.dumps( - {"result": str(ci_output)}, ensure_ascii=False - ) + output_json = json.dumps({'result': str(ci_output)}, ensure_ascii=False) output_attr = f' output="{html.escape(output_json)}"' - if status == "completed" or duration is not None or not is_last_item: + if status == 'completed' or duration is not None or not is_last_item: content += f'
\nAnalyzed\n{display}\n
\n' else: content += f'
\nAnalyzing…\n{display}\n
\n' @@ -575,19 +560,19 @@ def handle_responses_streaming_event( # Note: treating current_output as immutable, but avoiding full deepcopy for perf. # We will shallow copy only if we need to modify the list structure or items. - event_type = data.get("type", "") + event_type = data.get('type', '') - if event_type == "response.output_item.added": - item = data.get("item", {}) + if event_type == 'response.output_item.added': + item = data.get('item', {}) if item: new_output = list(current_output) new_output.append(item) return new_output, None return current_output, None - elif event_type == "response.content_part.added": - part = data.get("part", {}) - output_index = data.get("output_index", len(current_output) - 1) + elif event_type == 'response.content_part.added': + part = data.get('part', {}) + output_index = data.get('output_index', len(current_output) - 1) if current_output and 0 <= output_index < len(current_output): new_output = list(current_output) @@ -595,83 +580,83 @@ def handle_responses_streaming_event( item = new_output[output_index].copy() new_output[output_index] = item - if "content" not in item: - item["content"] = [] + if 'content' not in item: + item['content'] = [] else: # Copy content list - item["content"] = list(item["content"]) + item['content'] = list(item['content']) - if item.get("type") == "reasoning": + if item.get('type') == 'reasoning': # Reasoning items should not have content parts pass else: - item["content"].append(part) + item['content'].append(part) return new_output, None return current_output, None - elif event_type == "response.reasoning_summary_part.added": - part = data.get("part", {}) - output_index = data.get("output_index", len(current_output) - 1) + elif event_type == 'response.reasoning_summary_part.added': + part = data.get('part', {}) + output_index = data.get('output_index', len(current_output) - 1) if current_output and 0 <= output_index < len(current_output): new_output = list(current_output) item = new_output[output_index].copy() new_output[output_index] = item - if "summary" not in item: - item["summary"] = [] + if 'summary' not in item: + item['summary'] = [] else: - item["summary"] = list(item["summary"]) + item['summary'] = list(item['summary']) - item["summary"].append(part) + item['summary'].append(part) return new_output, None return current_output, None - elif event_type.startswith("response.") and event_type.endswith(".delta"): + elif event_type.startswith('response.') and event_type.endswith('.delta'): # Generic Delta Handling - parts = event_type.split(".") + parts = event_type.split('.') if len(parts) >= 3: delta_type = parts[1] - delta = data.get("delta", "") + delta = data.get('delta', '') - output_index = data.get("output_index", len(current_output) - 1) + output_index = data.get('output_index', len(current_output) - 1) if current_output and 0 <= output_index < len(current_output): new_output = list(current_output) item = new_output[output_index].copy() new_output[output_index] = item - item_type = item.get("type", "") + item_type = item.get('type', '') # Determine target field and object based on delta_type and item_type - if delta_type == "function_call_arguments": - key = "arguments" - if item_type == "function_call": + if delta_type == 'function_call_arguments': + key = 'arguments' + if item_type == 'function_call': # Function call args are usually strings - item[key] = item.get(key, "") + str(delta) + item[key] = item.get(key, '') + str(delta) else: # Generic handling, refined by item type below pass - if item_type == "message": + if item_type == 'message': # Message items: "text"/"output_text" -> "text" # "reasoning_text" -> Skipped (should use reasoning item) - if delta_type in ["text", "output_text"]: - key = "text" - elif delta_type in ["reasoning_text", "reasoning_summary_text"]: + if delta_type in ['text', 'output_text']: + key = 'text' + elif delta_type in ['reasoning_text', 'reasoning_summary_text']: # Skip reasoning updates for message items return new_output, None else: key = delta_type - content_index = data.get("content_index", 0) - if "content" not in item: - item["content"] = [] + content_index = data.get('content_index', 0) + if 'content' not in item: + item['content'] = [] else: - item["content"] = list(item["content"]) - content_list = item["content"] + item['content'] = list(item['content']) + content_list = item['content'] while len(content_list) <= content_index: - content_list.append({"type": "text", "text": ""}) + content_list.append({'type': 'text', 'text': ''}) # Copy the part to mutate it part = content_list[content_index].copy() @@ -680,55 +665,53 @@ def handle_responses_streaming_event( current_val = part.get(key) if current_val is None: # Initialize based on delta type - current_val = {} if isinstance(delta, dict) else "" + current_val = {} if isinstance(delta, dict) else '' part[key] = deep_merge(current_val, delta) - elif item_type == "reasoning": + elif item_type == 'reasoning': # Reasoning items: "reasoning_text"/"reasoning_summary_text" -> "text" # "text"/"output_text" -> Skipped (should use message item) - if delta_type == "reasoning_summary_text": + if delta_type == 'reasoning_summary_text': # Summary updates -> item['summary'] - key = "text" - summary_index = data.get("summary_index", 0) - if "summary" not in item: - item["summary"] = [] + key = 'text' + summary_index = data.get('summary_index', 0) + if 'summary' not in item: + item['summary'] = [] else: - item["summary"] = list(item["summary"]) - summary_list = item["summary"] + item['summary'] = list(item['summary']) + summary_list = item['summary'] while len(summary_list) <= summary_index: - summary_list.append( - {"type": "summary_text", "text": ""} - ) + summary_list.append({'type': 'summary_text', 'text': ''}) part = summary_list[summary_index].copy() summary_list[summary_index] = part - target_val = part.get(key, "") + target_val = part.get(key, '') part[key] = deep_merge(target_val, delta) - elif delta_type == "reasoning_text": + elif delta_type == 'reasoning_text': # Reasoning body updates -> item['content'] - key = "text" - content_index = data.get("content_index", 0) - if "content" not in item: - item["content"] = [] + key = 'text' + content_index = data.get('content_index', 0) + if 'content' not in item: + item['content'] = [] else: - item["content"] = list(item["content"]) - content_list = item["content"] + item['content'] = list(item['content']) + content_list = item['content'] while len(content_list) <= content_index: # Reasoning content parts default to text - content_list.append({"type": "text", "text": ""}) + content_list.append({'type': 'text', 'text': ''}) part = content_list[content_index].copy() content_list[content_index] = part - target_val = part.get(key, "") + target_val = part.get(key, '') part[key] = deep_merge(target_val, delta) - elif delta_type in ["text", "output_text"]: + elif delta_type in ['text', 'output_text']: return new_output, None else: # Fallback just in case other deltas target reasoning? @@ -736,109 +719,104 @@ def handle_responses_streaming_event( else: # Fallback for other item types - if delta_type in ["text", "output_text"]: - key = "text" + if delta_type in ['text', 'output_text']: + key = 'text' else: key = delta_type current_val = item.get(key) if current_val is None: - current_val = {} if isinstance(delta, dict) else "" + current_val = {} if isinstance(delta, dict) else '' item[key] = deep_merge(current_val, delta) return new_output, None - elif event_type.startswith("response.") and event_type.endswith(".done"): + elif event_type.startswith('response.') and event_type.endswith('.done'): # Delta Events: response.content_part.done, response.text.done, etc. - parts = event_type.split(".") + parts = event_type.split('.') if len(parts) >= 3: type_name = parts[1] # 1. Handle specific Delta "done" signals - if type_name == "content_part": + if type_name == 'content_part': # "Signaling that no further changes will occur to a content part" # If payloads contains the full part, we could update it. # Usually purely signaling in standard implementation, but we check payload. - part = data.get("part") - output_index = data.get("output_index", len(current_output) - 1) + part = data.get('part') + output_index = data.get('output_index', len(current_output) - 1) if part and current_output and 0 <= output_index < len(current_output): new_output = list(current_output) item = new_output[output_index].copy() new_output[output_index] = item - if "content" in item: - item["content"] = list(item["content"]) - content_index = data.get( - "content_index", len(item["content"]) - 1 - ) - if 0 <= content_index < len(item["content"]): - item["content"][content_index] = part + if 'content' in item: + item['content'] = list(item['content']) + content_index = data.get('content_index', len(item['content']) - 1) + if 0 <= content_index < len(item['content']): + item['content'][content_index] = part return new_output, {} return current_output, None - elif type_name == "reasoning_summary_part": - part = data.get("part") - output_index = data.get("output_index", len(current_output) - 1) + elif type_name == 'reasoning_summary_part': + part = data.get('part') + output_index = data.get('output_index', len(current_output) - 1) if part and current_output and 0 <= output_index < len(current_output): new_output = list(current_output) item = new_output[output_index].copy() new_output[output_index] = item - if "summary" in item: - item["summary"] = list(item["summary"]) - summary_index = data.get( - "summary_index", len(item["summary"]) - 1 - ) - if 0 <= summary_index < len(item["summary"]): - item["summary"][summary_index] = part + if 'summary' in item: + item['summary'] = list(item['summary']) + summary_index = data.get('summary_index', len(item['summary']) - 1) + if 0 <= summary_index < len(item['summary']): + item['summary'][summary_index] = part return new_output, {} return current_output, None # 2. Skip Output Item done (handled specifically below) - if type_name == "output_item": + if type_name == 'output_item': pass # 3. Generic Field Done (text.done, audio.done) - elif type_name not in ["completed", "failed"]: - output_index = data.get("output_index", len(current_output) - 1) + elif type_name not in ['completed', 'failed']: + output_index = data.get('output_index', len(current_output) - 1) if current_output and 0 <= output_index < len(current_output): - key = ( - "text" + 'text' if type_name in [ - "text", - "output_text", - "reasoning_text", - "reasoning_summary_text", + 'text', + 'output_text', + 'reasoning_text', + 'reasoning_summary_text', ] else type_name ) - if type_name == "function_call_arguments": - key = "arguments" + if type_name == 'function_call_arguments': + key = 'arguments' if key in data: final_value = data[key] new_output = list(current_output) item = new_output[output_index].copy() new_output[output_index] = item - item_type = item.get("type", "") + item_type = item.get('type', '') - if type_name == "function_call_arguments": - if item_type == "function_call": - item["arguments"] = final_value - elif item_type == "message": - content_index = data.get("content_index", 0) - if "content" in item: - item["content"] = list(item["content"]) - if len(item["content"]) > content_index: - part = item["content"][content_index].copy() - item["content"][content_index] = part + if type_name == 'function_call_arguments': + if item_type == 'function_call': + item['arguments'] = final_value + elif item_type == 'message': + content_index = data.get('content_index', 0) + if 'content' in item: + item['content'] = list(item['content']) + if len(item['content']) > content_index: + part = item['content'][content_index].copy() + item['content'][content_index] = part part[key] = final_value - elif item_type == "reasoning": - item["status"] = "completed" + elif item_type == 'reasoning': + item['status'] = 'completed' else: item[key] = final_value @@ -846,10 +824,10 @@ def handle_responses_streaming_event( return current_output, None - elif event_type == "response.output_item.done": + elif event_type == 'response.output_item.done': # Delta Event: Output item complete - item = data.get("item") - output_index = data.get("output_index", len(current_output) - 1) + item = data.get('item') + output_index = data.get('output_index', len(current_output) - 1) new_output = list(current_output) if item and 0 <= output_index < len(current_output): @@ -858,60 +836,53 @@ def handle_responses_streaming_event( new_output.append(item) return new_output, {} - elif event_type == "response.completed": + elif event_type == 'response.completed': # State Machine Event: Completed - response_data = data.get("response", {}) - final_output = response_data.get("output") + response_data = data.get('response', {}) + final_output = response_data.get('output') new_output = final_output if final_output is not None else current_output # Ensure reasoning items are marked as completed in the final output if new_output: for item in new_output: - if ( - item.get("type") == "reasoning" - and item.get("status") != "completed" - ): - item["status"] = "completed" + if item.get('type') == 'reasoning' and item.get('status') != 'completed': + item['status'] = 'completed' - return new_output, {"usage": response_data.get("usage"), "done": True} + return new_output, {'usage': response_data.get('usage'), 'done': True} - elif event_type == "response.in_progress": + elif event_type == 'response.in_progress': # State Machine Event: In Progress # We could extract metadata if needed, but for now just acknowledge iteration return current_output, None - elif event_type == "response.failed": + elif event_type == 'response.failed': # State Machine Event: Failed - error = data.get("response", {}).get("error", {}) - return current_output, {"error": error} + error = data.get('response', {}).get('error', {}) + return current_output, {'error': error} else: return current_output, None -def get_source_context( - sources: list, source_ids: dict = None, include_content: bool = True -) -> str: +def get_source_context(sources: list, source_ids: dict = None, include_content: bool = True) -> str: """ Build tag context string from citation sources. """ - context_string = "" + context_string = '' if source_ids is None: source_ids = {} for source in sources: - for doc, meta in zip(source.get("document", []), source.get("metadata", [])): - source_id = ( - meta.get("source") or source.get("source", {}).get("id") or "N/A" - ) + for doc, meta in zip(source.get('document', []), source.get('metadata', [])): + source_id = meta.get('source') or source.get('source', {}).get('id') or 'N/A' if source_id not in source_ids: source_ids[source_id] = len(source_ids) + 1 - src_name = source.get("source", {}).get("name") - body = doc if include_content else "" + src_name = source.get('source', {}).get('name') + body = doc if include_content else '' context_string += ( f'{body}\n" + + (f' name="{src_name}"' if src_name else '') + + f'>{body}\n' ) return context_string @@ -964,40 +935,40 @@ def process_tool_result( user=None, ): tool_result_embeds = [] - EXTERNAL_TOOL_TYPES = ("external", "action", "terminal") + EXTERNAL_TOOL_TYPES = ('external', 'action', 'terminal') if isinstance(tool_result, HTMLResponse): - content_disposition = tool_result.headers.get("Content-Disposition", "") - if "inline" in content_disposition: - content = tool_result.body.decode("utf-8", "replace") + content_disposition = tool_result.headers.get('Content-Disposition', '') + if 'inline' in content_disposition: + content = tool_result.body.decode('utf-8', 'replace') tool_result_embeds.append(content) if 200 <= tool_result.status_code < 300: tool_result = { - "status": "success", - "code": "ui_component", - "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + 'status': 'success', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Embedded UI result is active and visible to the user.', } elif 400 <= tool_result.status_code < 500: tool_result = { - "status": "error", - "code": "ui_component", - "message": f"{tool_function_name}: Client error {tool_result.status_code} from embedded UI result.", + 'status': 'error', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Client error {tool_result.status_code} from embedded UI result.', } elif 500 <= tool_result.status_code < 600: tool_result = { - "status": "error", - "code": "ui_component", - "message": f"{tool_function_name}: Server error {tool_result.status_code} from embedded UI result.", + 'status': 'error', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Server error {tool_result.status_code} from embedded UI result.', } else: tool_result = { - "status": "error", - "code": "ui_component", - "message": f"{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.", + 'status': 'error', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.', } else: - tool_result = tool_result.body.decode("utf-8", "replace") + tool_result = tool_result.body.decode('utf-8', 'replace') elif (tool_type in EXTERNAL_TOOL_TYPES and isinstance(tool_result, tuple)) or ( direct_tool and isinstance(tool_result, list) and len(tool_result) == 2 @@ -1013,84 +984,84 @@ def process_tool_result( if tool_response_headers and isinstance(tool_response_headers, dict): content_disposition = tool_response_headers.get( - "Content-Disposition", - tool_response_headers.get("content-disposition", ""), + 'Content-Disposition', + tool_response_headers.get('content-disposition', ''), ) - if "inline" in content_disposition: + if 'inline' in content_disposition: content_type = tool_response_headers.get( - "Content-Type", - tool_response_headers.get("content-type", ""), + 'Content-Type', + tool_response_headers.get('content-type', ''), ) location = tool_response_headers.get( - "Location", - tool_response_headers.get("location", ""), + 'Location', + tool_response_headers.get('location', ''), ) - if "text/html" in content_type: + if 'text/html' in content_type: # Display as iframe embed tool_result_embeds.append(tool_result) tool_result = { - "status": "success", - "code": "ui_component", - "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + 'status': 'success', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Embedded UI result is active and visible to the user.', } elif location: tool_result_embeds.append(location) tool_result = { - "status": "success", - "code": "ui_component", - "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + 'status': 'success', + 'code': 'ui_component', + 'message': f'{tool_function_name}: Embedded UI result is active and visible to the user.', } tool_result_files = [] if isinstance(tool_result, list): - if tool_type == "mcp": # MCP + if tool_type == 'mcp': # MCP tool_response = [] for item in tool_result: if isinstance(item, dict): - if item.get("type") == "text": - text = item.get("text", "") + if item.get('type') == 'text': + text = item.get('text', '') if isinstance(text, str): try: text = json.loads(text) except json.JSONDecodeError: pass tool_response.append(text) - elif item.get("type") in ["image", "audio"]: + elif item.get('type') in ['image', 'audio']: file_url = get_file_url_from_base64( request, - f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}", + f'data:{item.get("mimeType")};base64,{item.get("data", item.get("blob", ""))}', { - "chat_id": metadata.get("chat_id", None), - "message_id": metadata.get("message_id", None), - "session_id": metadata.get("session_id", None), - "result": item, + 'chat_id': metadata.get('chat_id', None), + 'message_id': metadata.get('message_id', None), + 'session_id': metadata.get('session_id', None), + 'result': item, }, user, ) tool_result_files.append( { - "type": item.get("type", "data"), - "url": file_url, + 'type': item.get('type', 'data'), + 'url': file_url, } ) tool_result = tool_response[0] if len(tool_response) == 1 else tool_response else: # OpenAPI for item in tool_result: - if isinstance(item, str) and item.startswith("data:"): + if isinstance(item, str) and item.startswith('data:'): tool_result_files.append( { - "type": "data", - "content": item, + 'type': 'data', + 'content': item, } ) tool_result.remove(item) if isinstance(tool_result, list): - tool_result = {"results": tool_result} + tool_result = {'results': tool_result} if isinstance(tool_result, dict) or isinstance(tool_result, list): tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False) @@ -1101,11 +1072,7 @@ def process_tool_result( if tool_result is not None and not isinstance(tool_result, str): if isinstance(tool_result, tuple): # execute_tool_server returns (data, headers); unpack the data part - tool_result = ( - json.dumps(tool_result[0], indent=2, ensure_ascii=False) - if len(tool_result) > 0 - else "" - ) + tool_result = json.dumps(tool_result[0], indent=2, ensure_ascii=False) if len(tool_result) > 0 else '' else: tool_result = str(tool_result) @@ -1127,8 +1094,8 @@ async def terminal_event_handler( if not event_emitter: return - if tool_function_name == "display_file": - path = tool_function_params.get("path", "") + if tool_function_name == 'display_file': + path = tool_function_params.get('path', '') if not path: return # Only emit if the file actually exists @@ -1138,30 +1105,30 @@ async def terminal_event_handler( parsed = json.loads(parsed) except (json.JSONDecodeError, TypeError): pass - if isinstance(parsed, dict) and parsed.get("exists") is False: + if isinstance(parsed, dict) and parsed.get('exists') is False: return await event_emitter( { - "type": f"terminal:{tool_function_name}", - "data": {"path": path}, + 'type': f'terminal:{tool_function_name}', + 'data': {'path': path}, } ) - elif tool_function_name in ("write_file", "replace_file_content"): - path = tool_function_params.get("path", "") + elif tool_function_name in ('write_file', 'replace_file_content'): + path = tool_function_params.get('path', '') if not path: return await event_emitter( { - "type": f"terminal:{tool_function_name}", - "data": {"path": path}, + 'type': f'terminal:{tool_function_name}', + 'data': {'path': path}, } ) - elif tool_function_name == "run_command": + elif tool_function_name == 'run_command': await event_emitter( { - "type": "terminal:run_command", - "data": {}, + 'type': 'terminal:run_command', + 'data': {}, } ) @@ -1171,53 +1138,48 @@ async def chat_completion_tools_handler( ) -> tuple[dict, dict]: async def get_content_from_response(response) -> Optional[str]: content = None - if hasattr(response, "body_iterator"): + if hasattr(response, 'body_iterator'): async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8", "replace")) - content = data["choices"][0]["message"]["content"] + data = json.loads(chunk.decode('utf-8', 'replace')) + content = data['choices'][0]['message']['content'] # Cleanup any remaining background tasks if necessary if response.background is not None: await response.background() else: - content = response["choices"][0]["message"]["content"] + content = response['choices'][0]['message']['content'] return content def get_tools_function_calling_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) - if user_message and messages and messages[-1]["role"] == "user": + if user_message and messages and messages[-1]['role'] == 'user': # Remove the last user message to avoid duplication messages = messages[:-1] recent_messages = messages[-4:] if len(messages) > 4 else messages - chat_history = "\n".join( - f"{message['role'].upper()}: \"\"\"{get_content_from_message(message)}\"\"\"" - for message in recent_messages + chat_history = '\n'.join( + f'{message["role"].upper()}: """{get_content_from_message(message)}"""' for message in recent_messages ) - prompt = ( - f"History:\n{chat_history}\nQuery: {user_message}" - if chat_history - else f"Query: {user_message}" - ) + prompt = f'History:\n{chat_history}\nQuery: {user_message}' if chat_history else f'Query: {user_message}' return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": prompt}, + 'model': task_model_id, + 'messages': [ + {'role': 'system', 'content': content}, + {'role': 'user', 'content': prompt}, ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + 'stream': False, + 'metadata': {'task': str(TASKS.FUNCTION_CALLING)}, } - event_caller = extra_params["__event_call__"] - event_emitter = extra_params["__event_emitter__"] - metadata = extra_params["__metadata__"] + event_caller = extra_params['__event_call__'] + event_emitter = extra_params['__event_emitter__'] + metadata = extra_params['__metadata__'] task_model_id = get_task_model_id( - body["model"], + body['model'], request.app.state.config.TASK_MODEL, request.app.state.config.TASK_MODEL_EXTERNAL, models, @@ -1226,97 +1188,85 @@ async def chat_completion_tools_handler( skip_files = False sources = [] - specs = [tool["spec"] for tool in tools.values()] + specs = [tool['spec'] for tool in tools.values()] tools_specs = json.dumps(specs, ensure_ascii=False) - if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": + if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != '': template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE else: template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs - ) - payload = get_tools_function_calling_payload( - body["messages"], task_model_id, tools_function_calling_prompt - ) + tools_function_calling_prompt = tools_function_calling_generation_template(template, tools_specs) + payload = get_tools_function_calling_payload(body['messages'], task_model_id, tools_function_calling_prompt) try: response = await generate_chat_completion(request, form_data=payload, user=user) - log.debug(f"{response=}") + log.debug(f'{response=}') content = await get_content_from_response(response) - log.debug(f"{content=}") + log.debug(f'{content=}') if not content: return body, {} try: - content = content[content.find("{") : content.rfind("}") + 1] + content = content[content.find('{') : content.rfind('}') + 1] if not content: - raise Exception("No JSON object found in the response") + raise Exception('No JSON object found in the response') result = json.loads(content) async def tool_call_handler(tool_call): nonlocal skip_files - log.debug(f"{tool_call=}") + log.debug(f'{tool_call=}') - tool_function_name = tool_call.get("name", None) + tool_function_name = tool_call.get('name', None) if tool_function_name not in tools: return body, {} - tool_function_params = tool_call.get("parameters", {}) + tool_function_params = tool_call.get('parameters', {}) tool = None - tool_type = "" + tool_type = '' direct_tool = False try: tool = tools[tool_function_name] - tool_type = tool.get("type", "") - direct_tool = tool.get("direct", False) + tool_type = tool.get('type', '') + direct_tool = tool.get('direct', False) - spec = tool.get("spec", {}) - allowed_params = ( - spec.get("parameters", {}).get("properties", {}).keys() - ) - tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in allowed_params - } + spec = tool.get('spec', {}) + allowed_params = spec.get('parameters', {}).get('properties', {}).keys() + tool_function_params = {k: v for k, v in tool_function_params.items() if k in allowed_params} - if tool.get("direct", False): + if tool.get('direct', False): tool_result = await event_caller( { - "type": "execute:tool", - "data": { - "id": str(uuid4()), - "name": tool_function_name, - "params": tool_function_params, - "server": tool.get("server", {}), - "session_id": metadata.get("session_id", None), + 'type': 'execute:tool', + 'data': { + 'id': str(uuid4()), + 'name': tool_function_name, + 'params': tool_function_params, + 'server': tool.get('server', {}), + 'session_id': metadata.get('session_id', None), }, } ) else: - tool_function = tool["callable"] + tool_function = tool['callable'] tool_result = await tool_function(**tool_function_params) except Exception as e: tool_result = str(e) - tool_result, tool_result_files, tool_result_embeds = ( - process_tool_result( - request, - tool_function_name, - tool_result, - tool_type, - direct_tool, - metadata, - user, - ) + tool_result, tool_result_files, tool_result_embeds = process_tool_result( + request, + tool_function_name, + tool_result, + tool_type, + direct_tool, + metadata, + user, ) if event_emitter: @@ -1330,9 +1280,9 @@ async def chat_completion_tools_handler( if tool_result_files: await event_emitter( { - "type": "files", - "data": { - "files": tool_result_files, + 'type': 'files', + 'data': { + 'files': tool_result_files, }, } ) @@ -1340,79 +1290,69 @@ async def chat_completion_tools_handler( if tool_result_embeds: await event_emitter( { - "type": "embeds", - "data": { - "embeds": tool_result_embeds, + 'type': 'embeds', + 'data': { + 'embeds': tool_result_embeds, }, } ) if tool_result: tool = tools[tool_function_name] - tool_id = tool.get("tool_id", "") + tool_id = tool.get('tool_id', '') - tool_name = ( - f"{tool_id}/{tool_function_name}" - if tool_id - else f"{tool_function_name}" - ) + tool_name = f'{tool_id}/{tool_function_name}' if tool_id else f'{tool_function_name}' # Citation is enabled for this tool sources.append( { - "source": { - "name": (f"{tool_name}"), + 'source': { + 'name': (f'{tool_name}'), }, - "document": [str(tool_result)], - "metadata": [ + 'document': [str(tool_result)], + 'metadata': [ { - "source": (f"{tool_name}"), - "parameters": tool_function_params, + 'source': (f'{tool_name}'), + 'parameters': tool_function_params, } ], - "tool_result": True, + 'tool_result': True, } ) - if ( - tools[tool_function_name] - .get("metadata", {}) - .get("file_handler", False) - ): + if tools[tool_function_name].get('metadata', {}).get('file_handler', False): skip_files = True # check if "tool_calls" in result - if result.get("tool_calls"): - for tool_call in result.get("tool_calls"): + if result.get('tool_calls'): + for tool_call in result.get('tool_calls'): await tool_call_handler(tool_call) else: await tool_call_handler(result) except Exception as e: - log.debug(f"Error: {e}") + log.debug(f'Error: {e}') content = None except Exception as e: - log.debug(f"Error: {e}") + log.debug(f'Error: {e}') content = None - log.debug(f"tool_contexts: {sources}") + log.debug(f'tool_contexts: {sources}') - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] + if skip_files and 'files' in body.get('metadata', {}): + del body['metadata']['files'] - return body, {"sources": sources} + return body, {'sources': sources} -async def chat_memory_handler( - request: Request, form_data: dict, extra_params: dict, user -): +async def chat_memory_handler(request: Request, form_data: dict, extra_params: dict, user): try: results = await query_memory( request, QueryMemoryForm( **{ - "content": get_last_user_message(form_data["messages"]) or "", - "k": 3, + 'content': get_last_user_message(form_data['messages']) or '', + 'k': 3, } ), user, @@ -1421,43 +1361,39 @@ async def chat_memory_handler( log.debug(e) results = None - user_context = "" - if results and hasattr(results, "documents"): + user_context = '' + if results and hasattr(results, 'documents'): if results.documents and len(results.documents) > 0: for doc_idx, doc in enumerate(results.documents[0]): - created_at_date = "Unknown Date" + created_at_date = 'Unknown Date' - if results.metadatas[0][doc_idx].get("created_at"): - created_at_timestamp = results.metadatas[0][doc_idx]["created_at"] - created_at_date = time.strftime( - "%Y-%m-%d", time.localtime(created_at_timestamp) - ) + if results.metadatas[0][doc_idx].get('created_at'): + created_at_timestamp = results.metadatas[0][doc_idx]['created_at'] + created_at_date = time.strftime('%Y-%m-%d', time.localtime(created_at_timestamp)) - user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n" + user_context += f'{doc_idx + 1}. [{created_at_date}] {doc}\n' - form_data["messages"] = add_or_update_system_message( - f"User Context:\n{user_context}\n", form_data["messages"], append=True + form_data['messages'] = add_or_update_system_message( + f'User Context:\n{user_context}\n', form_data['messages'], append=True ) return form_data -async def chat_web_search_handler( - request: Request, form_data: dict, extra_params: dict, user -): - event_emitter = extra_params["__event_emitter__"] +async def chat_web_search_handler(request: Request, form_data: dict, extra_params: dict, user): + event_emitter = extra_params['__event_emitter__'] await event_emitter( { - "type": "status", - "data": { - "action": "web_search", - "description": "Searching the web", - "done": False, + 'type': 'status', + 'data': { + 'action': 'web_search', + 'description': 'Searching the web', + 'done': False, }, } ) - messages = form_data["messages"] + messages = form_data['messages'] user_message = get_last_user_message(messages) queries = [] @@ -1465,27 +1401,27 @@ async def chat_web_search_handler( res = await generate_queries( request, { - "model": form_data["model"], - "messages": messages, - "prompt": user_message, - "type": "web_search", - "chat_id": extra_params.get("__chat_id__"), + 'model': form_data['model'], + 'messages': messages, + 'prompt': user_message, + 'type': 'web_search', + 'chat_id': extra_params.get('__chat_id__'), }, user, ) - response = res["choices"][0]["message"]["content"] + response = res['choices'][0]['message']['content'] try: - bracket_start = response.find("{") - bracket_end = response.rfind("}") + 1 + bracket_start = response.find('{') + bracket_end = response.rfind('}') + 1 if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") + raise Exception('No JSON object found in the response') response = response[bracket_start:bracket_end] queries = json.loads(response) - queries = queries.get("queries", []) + queries = queries.get('queries', []) except Exception as e: queries = [response] @@ -1497,18 +1433,18 @@ async def chat_web_search_handler( queries = [user_message] # Check if generated queries are empty - if len(queries) == 1 and queries[0].strip() == "": + if len(queries) == 1 and queries[0].strip() == '': queries = [user_message] # Check if queries are not found if len(queries) == 0: await event_emitter( { - "type": "status", - "data": { - "action": "web_search", - "description": "No search query generated", - "done": True, + 'type': 'status', + 'data': { + 'action': 'web_search', + 'description': 'No search query generated', + 'done': True, }, } ) @@ -1516,11 +1452,11 @@ async def chat_web_search_handler( await event_emitter( { - "type": "status", - "data": { - "action": "web_search_queries_generated", - "queries": queries, - "done": False, + 'type': 'status', + 'data': { + 'action': 'web_search_queries_generated', + 'queries': queries, + 'done': False, }, } ) @@ -1533,57 +1469,55 @@ async def chat_web_search_handler( ) if results: - files = form_data.get("files", []) + files = form_data.get('files', []) - if results.get("collection_names"): - for col_idx, collection_name in enumerate( - results.get("collection_names") - ): + if results.get('collection_names'): + for col_idx, collection_name in enumerate(results.get('collection_names')): files.append( { - "collection_name": collection_name, - "name": ", ".join(queries), - "type": "web_search", - "urls": results["filenames"], - "queries": queries, + 'collection_name': collection_name, + 'name': ', '.join(queries), + 'type': 'web_search', + 'urls': results['filenames'], + 'queries': queries, } ) - elif results.get("docs"): + elif results.get('docs'): # Invoked when bypass embedding and retrieval is set to True - docs = results["docs"] + docs = results['docs'] files.append( { - "docs": docs, - "name": ", ".join(queries), - "type": "web_search", - "urls": results["filenames"], - "queries": queries, + 'docs': docs, + 'name': ', '.join(queries), + 'type': 'web_search', + 'urls': results['filenames'], + 'queries': queries, } ) - form_data["files"] = files + form_data['files'] = files await event_emitter( { - "type": "status", - "data": { - "action": "web_search", - "description": "Searched {{count}} sites", - "urls": results["filenames"], - "items": results.get("items", []), - "done": True, + 'type': 'status', + 'data': { + 'action': 'web_search', + 'description': 'Searched {{count}} sites', + 'urls': results['filenames'], + 'items': results.get('items', []), + 'done': True, }, } ) else: await event_emitter( { - "type": "status", - "data": { - "action": "web_search", - "description": "No search results found", - "done": True, - "error": True, + 'type': 'status', + 'data': { + 'action': 'web_search', + 'description': 'No search results found', + 'done': True, + 'error': True, }, } ) @@ -1592,13 +1526,13 @@ async def chat_web_search_handler( log.exception(e) await event_emitter( { - "type": "status", - "data": { - "action": "web_search", - "description": "An error occurred while searching the web", - "queries": queries, - "done": True, - "error": True, + 'type': 'status', + 'data': { + 'action': 'web_search', + 'description': 'An error occurred while searching the web', + 'queries': queries, + 'done': True, + 'error': True, }, } ) @@ -1610,13 +1544,12 @@ def get_images_from_messages(message_list): images = [] for message in reversed(message_list): - message_images = [] - for file in message.get("files", []): - if file.get("type") == "image": - message_images.append(file.get("url")) - elif file.get("content_type", "").startswith("image/"): - message_images.append(file.get("url")) + for file in message.get('files', []): + if file.get('type') == 'image': + message_images.append(file.get('url')) + elif file.get('content_type', '').startswith('image/'): + message_images.append(file.get('url')) if message_images: images.append(message_images) @@ -1630,14 +1563,14 @@ def get_image_urls(delta_images, request, metadata, user) -> list[str]: image_urls = [] for img in delta_images: - if not isinstance(img, dict) or img.get("type") != "image_url": + if not isinstance(img, dict) or img.get('type') != 'image_url': continue - url = img.get("image_url", {}).get("url") + url = img.get('image_url', {}).get('url') if not url: continue - if url.startswith("data:image/png;base64"): + if url.startswith('data:image/png;base64'): url = get_image_url_from_base64(request, url, metadata, user) image_urls.append(url) @@ -1649,72 +1582,66 @@ def add_file_context(messages: list, chat_id: str, user) -> list: """ Add file URLs to messages for native function calling. """ - if not chat_id or chat_id.startswith("local:"): + if not chat_id or chat_id.startswith('local:'): return messages chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) if not chat: return messages - history = chat.chat.get("history", {}) - stored_messages = get_message_list( - history.get("messages", {}), history.get("currentId") - ) + history = chat.chat.get('history', {}) + stored_messages = get_message_list(history.get('messages', {}), history.get('currentId')) def format_file_tag(file): attrs = f'type="{file.get("type", "file")}" url="{file["url"]}"' - if file.get("content_type"): + if file.get('content_type'): attrs += f' content_type="{file["content_type"]}"' - if file.get("name"): + if file.get('name'): attrs += f' name="{file["name"]}"' - return f"" + return f'' for message, stored_message in zip(messages, stored_messages): files_with_urls = [ file - for file in stored_message.get("files", []) - if file.get("url") and not file.get("url").startswith("data:") + for file in stored_message.get('files', []) + if file.get('url') and not file.get('url').startswith('data:') ] if not files_with_urls: continue file_tags = [format_file_tag(file) for file in files_with_urls] - file_context = ( - "\n" + "\n".join(file_tags) + "\n\n\n" - ) + file_context = '\n' + '\n'.join(file_tags) + '\n\n\n' - content = message.get("content", "") + content = message.get('content', '') if isinstance(content, list): - message["content"] = [{"type": "text", "text": file_context}] + content + message['content'] = [{'type': 'text', 'text': file_context}] + content else: - message["content"] = file_context + content + message['content'] = file_context + content return messages -async def chat_image_generation_handler( - request: Request, form_data: dict, extra_params: dict, user -): - metadata = extra_params.get("__metadata__", {}) - chat_id = metadata.get("chat_id", None) - __event_emitter__ = extra_params.get("__event_emitter__", None) +async def chat_image_generation_handler(request: Request, form_data: dict, extra_params: dict, user): + metadata = extra_params.get('__metadata__', {}) + chat_id = metadata.get('chat_id', None) + __event_emitter__ = extra_params.get('__event_emitter__', None) if not chat_id or not isinstance(chat_id, str) or not __event_emitter__: return form_data - if chat_id.startswith("local:"): - message_list = form_data.get("messages", []) + if chat_id.startswith('local:'): + message_list = form_data.get('messages', []) else: chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) await __event_emitter__( { - "type": "status", - "data": {"description": "Creating image", "done": False}, + 'type': 'status', + 'data': {'description': 'Creating image', 'done': False}, } ) - messages_map = chat.chat.get("history", {}).get("messages", {}) - message_id = chat.chat.get("history", {}).get("currentId") + messages_map = chat.chat.get('history', {}).get('messages', {}) + message_id = chat.chat.get('history', {}).get('currentId') message_list = get_message_list(messages_map, message_id) user_message = get_last_user_message(message_list) @@ -1731,36 +1658,36 @@ async def chat_image_generation_handler( for image in images: input_images.append(image) - system_message_content = "" + system_message_content = '' if len(input_images) > 0 and request.app.state.config.ENABLE_IMAGE_EDIT: # Edit image(s) try: images = await image_edits( request=request, - form_data=EditImageForm(**{"prompt": prompt, "image": input_images}), + form_data=EditImageForm(**{'prompt': prompt, 'image': input_images}), metadata={ - "chat_id": metadata.get("chat_id", None), - "message_id": metadata.get("message_id", None), + 'chat_id': metadata.get('chat_id', None), + 'message_id': metadata.get('message_id', None), }, user=user, ) await __event_emitter__( { - "type": "status", - "data": {"description": "Image created", "done": True}, + 'type': 'status', + 'data': {'description': 'Image created', 'done': True}, } ) await __event_emitter__( { - "type": "files", - "data": { - "files": [ + 'type': 'files', + 'data': { + 'files': [ { - "type": "image", - "url": image["url"], + 'type': 'image', + 'url': image['url'], } for image in images ] @@ -1768,28 +1695,28 @@ async def chat_image_generation_handler( } ) - system_message_content = "The requested image has been edited and created and is now being shown to the user. Let them know that it has been generated." + system_message_content = 'The requested image has been edited and created and is now being shown to the user. Let them know that it has been generated.' except Exception as e: log.debug(e) - error_message = "" + error_message = '' if isinstance(e, HTTPException): if e.detail and isinstance(e.detail, dict): - error_message = e.detail.get("message", str(e.detail)) + error_message = e.detail.get('message', str(e.detail)) else: error_message = str(e.detail) await __event_emitter__( { - "type": "status", - "data": { - "description": f"An error occurred while generating an image", - "done": True, + 'type': 'status', + 'data': { + 'description': f'An error occurred while generating an image', + 'done': True, }, } ) - system_message_content = f"Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}" + system_message_content = f'Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}' else: # Create image(s) @@ -1798,25 +1725,25 @@ async def chat_image_generation_handler( res = await generate_image_prompt( request, { - "model": form_data["model"], - "messages": form_data["messages"], - "chat_id": metadata.get("chat_id"), + 'model': form_data['model'], + 'messages': form_data['messages'], + 'chat_id': metadata.get('chat_id'), }, user, ) - response = res["choices"][0]["message"]["content"] + response = res['choices'][0]['message']['content'] try: - bracket_start = response.find("{") - bracket_end = response.rfind("}") + 1 + bracket_start = response.find('{') + bracket_end = response.rfind('}') + 1 if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") + raise Exception('No JSON object found in the response') response = response[bracket_start:bracket_end] response = json.loads(response) - prompt = response.get("prompt", []) + prompt = response.get('prompt', []) except Exception as e: prompt = user_message @@ -1827,29 +1754,29 @@ async def chat_image_generation_handler( try: images = await image_generations( request=request, - form_data=CreateImageForm(**{"prompt": prompt}), + form_data=CreateImageForm(**{'prompt': prompt}), metadata={ - "chat_id": metadata.get("chat_id", None), - "message_id": metadata.get("message_id", None), + 'chat_id': metadata.get('chat_id', None), + 'message_id': metadata.get('message_id', None), }, user=user, ) await __event_emitter__( { - "type": "status", - "data": {"description": "Image created", "done": True}, + 'type': 'status', + 'data': {'description': 'Image created', 'done': True}, } ) await __event_emitter__( { - "type": "files", - "data": { - "files": [ + 'type': 'files', + 'data': { + 'files': [ { - "type": "image", - "url": image["url"], + 'type': 'image', + 'url': image['url'], } for image in images ] @@ -1857,33 +1784,31 @@ async def chat_image_generation_handler( } ) - system_message_content = "The requested image has been created by the system successfully and is now being shown to the user. Let the user know that the image they requested has been generated and is now shown in the chat." + system_message_content = 'The requested image has been created by the system successfully and is now being shown to the user. Let the user know that the image they requested has been generated and is now shown in the chat.' except Exception as e: log.debug(e) - error_message = "" + error_message = '' if isinstance(e, HTTPException): if e.detail and isinstance(e.detail, dict): - error_message = e.detail.get("message", str(e.detail)) + error_message = e.detail.get('message', str(e.detail)) else: error_message = str(e.detail) await __event_emitter__( { - "type": "status", - "data": { - "description": f"An error occurred while generating an image", - "done": True, + 'type': 'status', + 'data': { + 'description': f'An error occurred while generating an image', + 'done': True, }, } ) - system_message_content = f"Image generation was attempted but failed because of an error. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}" + system_message_content = f'Image generation was attempted but failed because of an error. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}' if system_message_content: - form_data["messages"] = add_or_update_system_message( - system_message_content, form_data["messages"] - ) + form_data['messages'] = add_or_update_system_message(system_message_content, form_data['messages']) return form_data @@ -1891,12 +1816,12 @@ async def chat_image_generation_handler( async def chat_completion_files_handler( request: Request, body: dict, extra_params: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: - __event_emitter__ = extra_params["__event_emitter__"] + __event_emitter__ = extra_params['__event_emitter__'] sources = [] - if files := body.get("metadata", {}).get("files", None): + if files := body.get('metadata', {}).get('files', None): # Check if all files are in full context mode - all_full_context = all(item.get("context") == "full" for item in files) + all_full_context = all(item.get('context') == 'full' for item in files) queries = [] if not all_full_context: @@ -1904,44 +1829,44 @@ async def chat_completion_files_handler( queries_response = await generate_queries( request, { - "model": body["model"], - "messages": body["messages"], - "type": "retrieval", - "chat_id": body.get("metadata", {}).get("chat_id"), + 'model': body['model'], + 'messages': body['messages'], + 'type': 'retrieval', + 'chat_id': body.get('metadata', {}).get('chat_id'), }, user, ) - queries_response = queries_response["choices"][0]["message"]["content"] + queries_response = queries_response['choices'][0]['message']['content'] try: - bracket_start = queries_response.find("{") - bracket_end = queries_response.rfind("}") + 1 + bracket_start = queries_response.find('{') + bracket_end = queries_response.rfind('}') + 1 if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") + raise Exception('No JSON object found in the response') queries_response = queries_response[bracket_start:bracket_end] queries_response = json.loads(queries_response) except Exception as e: - queries_response = {"queries": [queries_response]} + queries_response = {'queries': [queries_response]} - queries = queries_response.get("queries", []) + queries = queries_response.get('queries', []) except Exception: pass await __event_emitter__( { - "type": "status", - "data": { - "action": "queries_generated", - "queries": queries, - "done": False, + 'type': 'status', + 'data': { + 'action': 'queries_generated', + 'queries': queries, + 'done': False, }, } ) if len(queries) == 0: - queries = [get_last_user_message(body["messages"])] + queries = [get_last_user_message(body['messages'])] try: # Directly await async get_sources_from_items (no thread needed - fully async now) @@ -1954,11 +1879,7 @@ async def chat_completion_files_handler( ), k=request.app.state.config.TOP_K, reranking_function=( - ( - lambda query, documents: request.app.state.RERANKING_FUNCTION( - query, documents, user=user - ) - ) + (lambda query, documents: request.app.state.RERANKING_FUNCTION(query, documents, user=user)) if request.app.state.RERANKING_FUNCTION else None ), @@ -1966,58 +1887,53 @@ async def chat_completion_files_handler( r=request.app.state.config.RELEVANCE_THRESHOLD, hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - full_context=all_full_context - or request.app.state.config.RAG_FULL_CONTEXT, + full_context=all_full_context or request.app.state.config.RAG_FULL_CONTEXT, user=user, ) except Exception as e: log.exception(e) - log.debug(f"rag_contexts:sources: {sources}") + log.debug(f'rag_contexts:sources: {sources}') unique_ids = set() for source in sources or []: if not source or len(source.keys()) == 0: continue - documents = source.get("document") or [] - metadatas = source.get("metadata") or [] - src_info = source.get("source") or {} + documents = source.get('document') or [] + metadatas = source.get('metadata') or [] + src_info = source.get('source') or {} for index, _ in enumerate(documents): metadata = metadatas[index] if index < len(metadatas) else None - _id = ( - (metadata or {}).get("source") - or (src_info or {}).get("id") - or "N/A" - ) + _id = (metadata or {}).get('source') or (src_info or {}).get('id') or 'N/A' unique_ids.add(_id) sources_count = len(unique_ids) await __event_emitter__( { - "type": "status", - "data": { - "action": "sources_retrieved", - "count": sources_count, - "done": True, + 'type': 'status', + 'data': { + 'action': 'sources_retrieved', + 'count': sources_count, + 'done': True, }, } ) - return body, {"sources": sources} + return body, {'sources': sources} def apply_params_to_form_data(form_data, model): - params = form_data.pop("params", {}) - custom_params = params.pop("custom_params", {}) + params = form_data.pop('params', {}) + custom_params = params.pop('custom_params', {}) open_webui_params = { - "stream_response": bool, - "stream_delta_chunk_size": int, - "function_calling": str, - "reasoning_tags": list, - "system": str, + 'stream_response': bool, + 'stream_delta_chunk_size': int, + 'function_calling': str, + 'reasoning_tags': list, + 'system': str, } for key in list(params.keys()): @@ -2038,62 +1954,60 @@ def apply_params_to_form_data(form_data, model): # If custom_params are provided, merge them into params params = deep_update(params, custom_params) - if model.get("owned_by") == "ollama": + if model.get('owned_by') == 'ollama': # Ollama specific parameters - form_data["options"] = params + form_data['options'] = params else: if isinstance(params, dict): for key, value in params.items(): if value is not None: form_data[key] = value - if "logit_bias" in params and params["logit_bias"] is not None: + if 'logit_bias' in params and params['logit_bias'] is not None: try: - logit_bias = convert_logit_bias_input_to_json(params["logit_bias"]) + logit_bias = convert_logit_bias_input_to_json(params['logit_bias']) if logit_bias: - form_data["logit_bias"] = json.loads(logit_bias) + form_data['logit_bias'] = json.loads(logit_bias) except Exception as e: - log.exception(f"Error parsing logit_bias: {e}") + log.exception(f'Error parsing logit_bias: {e}') return form_data async def convert_url_images_to_base64(form_data): - messages = form_data.get("messages", []) + messages = form_data.get('messages', []) for message in messages: - content = message.get("content") + content = message.get('content') if not isinstance(content, list): continue new_content = [] for item in content: - if not isinstance(item, dict) or item.get("type") != "image_url": + if not isinstance(item, dict) or item.get('type') != 'image_url': new_content.append(item) continue - image_url = item.get("image_url", {}).get("url", "") - if image_url.startswith("data:image/"): + image_url = item.get('image_url', {}).get('url', '') + if image_url.startswith('data:image/'): new_content.append(item) continue try: - base64_data = await asyncio.to_thread( - get_image_base64_from_url, image_url - ) + base64_data = await asyncio.to_thread(get_image_base64_from_url, image_url) new_content.append( { - "type": "image_url", - "image_url": {"url": base64_data}, + 'type': 'image_url', + 'image_url': {'url': base64_data}, } ) except Exception as e: - log.debug(f"Error converting image URL to base64: {e}") + log.debug(f'Error converting image URL to base64: {e}') new_content.append(item) - message["content"] = new_content + message['content'] = new_content return form_data @@ -2111,10 +2025,7 @@ def load_messages_from_db(chat_id: str, message_id: str) -> Optional[list[dict]] if not db_messages: return None - return [ - {k: v for k, v in msg.items() if k in ("role", "content", "output", "files")} - for msg in db_messages - ] + return [{k: v for k, v in msg.items() if k in ('role', 'content', 'output', 'files')} for msg in db_messages] def process_messages_with_output(messages: list[dict]) -> list[dict]: @@ -2127,15 +2038,15 @@ def process_messages_with_output(messages: list[dict]) -> list[dict]: processed = [] for message in messages: - if message.get("role") == "assistant" and message.get("output"): + if message.get('role') == 'assistant' and message.get('output'): # Use output items for clean OpenAI-format messages - output_messages = convert_output_to_messages(message["output"], raw=True) + output_messages = convert_output_to_messages(message['output'], raw=True) if output_messages: processed.extend(output_messages) continue # Strip 'output' field before adding (LLM shouldn't see it) - clean_message = {k: v for k, v in message.items() if k != "output"} + clean_message = {k: v for k, v in message.items() if k != 'output'} processed.append(clean_message) return processed @@ -2147,54 +2058,51 @@ async def process_chat_payload(request, form_data, user, metadata, model): # -> Chat Files form_data = apply_params_to_form_data(form_data, model) - log.debug(f"form_data: {form_data}") + log.debug(f'form_data: {form_data}') # Load messages from DB when available — DB preserves structured 'output' items # which the frontend strips, causing tool calls to be merged into content. - chat_id = metadata.get("chat_id") - parent_message_id = metadata.get("parent_message_id") + chat_id = metadata.get('chat_id') + parent_message_id = metadata.get('parent_message_id') - if chat_id and parent_message_id and not chat_id.startswith("local:"): + if chat_id and parent_message_id and not chat_id.startswith('local:'): db_messages = load_messages_from_db(chat_id, parent_message_id) if db_messages: - system_message = get_system_message(form_data.get("messages", [])) - form_data["messages"] = ( - [system_message, *db_messages] if system_message else db_messages - ) + system_message = get_system_message(form_data.get('messages', [])) + form_data['messages'] = [system_message, *db_messages] if system_message else db_messages # Inject image files into content as image_url parts (mirrors frontend logic) - for message in form_data["messages"]: + for message in form_data['messages']: image_files = [ f - for f in message.get("files", []) - if f.get("type") == "image" - or (f.get("content_type") or "").startswith("image/") + for f in message.get('files', []) + if f.get('type') == 'image' or (f.get('content_type') or '').startswith('image/') ] - if message.get("role") == "user" and image_files: - text_content = message.get("content", "") + if message.get('role') == 'user' and image_files: + text_content = message.get('content', '') if isinstance(text_content, str): - message["content"] = [ - {"type": "text", "text": text_content}, + message['content'] = [ + {'type': 'text', 'text': text_content}, *[ { - "type": "image_url", - "image_url": {"url": f["url"]}, + 'type': 'image_url', + 'image_url': {'url': f['url']}, } for f in image_files - if f.get("url") + if f.get('url') ], ] # Strip files field — it's been incorporated into content - message.pop("files", None) + message.pop('files', None) # Process messages with OR-aligned output items for clean LLM messages - form_data["messages"] = process_messages_with_output(form_data.get("messages", [])) + form_data['messages'] = process_messages_with_output(form_data.get('messages', [])) - system_message = get_system_message(form_data.get("messages", [])) + system_message = get_system_message(form_data.get('messages', [])) if system_message: # Chat Controls/User Settings try: form_data = apply_system_prompt_to_body( - system_message.get("content"), form_data, metadata, user, replace=True + system_message.get('content'), form_data, metadata, user, replace=True ) # Required to handle system prompt variables except Exception: pass @@ -2205,27 +2113,27 @@ async def process_chat_payload(request, form_data, user, metadata, model): event_caller = get_event_call(metadata) extra_params = { - "__event_emitter__": event_emitter, - "__event_call__": event_caller, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__oauth_token__": await get_system_oauth_token(request, user), - "__request__": request, - "__model__": model, - "__chat_id__": metadata.get("chat_id"), - "__message_id__": metadata.get("message_id"), + '__event_emitter__': event_emitter, + '__event_call__': event_caller, + '__user__': user.model_dump() if isinstance(user, UserModel) else {}, + '__metadata__': metadata, + '__oauth_token__': await get_system_oauth_token(request, user), + '__request__': request, + '__model__': model, + '__chat_id__': metadata.get('chat_id'), + '__message_id__': metadata.get('message_id'), } # Initialize events to store additional event to be sent to the client # Initialize contexts and citation - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'): models = { - request.state.model["id"]: request.state.model, + request.state.model['id']: request.state.model, } else: models = request.app.state.MODELS task_model_id = get_task_model_id( - form_data["model"], + form_data['model'], request.app.state.config.TASK_MODEL, request.app.state.config.TASK_MODEL_EXTERNAL, models, @@ -2237,215 +2145,194 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Folder "Project" handling # Check if the request has chat_id and is inside of a folder # Uses lightweight column query — only fetches folder_id, not the full chat JSON blob - chat_id = metadata.get("chat_id", None) + chat_id = metadata.get('chat_id', None) if chat_id and user: folder_id = Chats.get_chat_folder_id(chat_id, user.id) if folder_id: folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id) if folder and folder.data: - if "system_prompt" in folder.data: - form_data = apply_system_prompt_to_body( - folder.data["system_prompt"], form_data, metadata, user - ) - if "files" in folder.data: - if metadata.get("params", {}).get("function_calling") != "native": - form_data["files"] = [ - *folder.data["files"], - *form_data.get("files", []), + if 'system_prompt' in folder.data: + form_data = apply_system_prompt_to_body(folder.data['system_prompt'], form_data, metadata, user) + if 'files' in folder.data: + if metadata.get('params', {}).get('function_calling') != 'native': + form_data['files'] = [ + *folder.data['files'], + *form_data.get('files', []), ] else: # Native FC: skip RAG injection, builtin tools # will read folder knowledge from metadata. - metadata["folder_knowledge"] = folder.data["files"] + metadata['folder_knowledge'] = folder.data['files'] # Model "Knowledge" handling - user_message = get_last_user_message(form_data["messages"]) - model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False) + user_message = get_last_user_message(form_data['messages']) + model_knowledge = model.get('info', {}).get('meta', {}).get('knowledge', False) - if ( - model_knowledge - and metadata.get("params", {}).get("function_calling") != "native" - ): + if model_knowledge and metadata.get('params', {}).get('function_calling') != 'native': await event_emitter( { - "type": "status", - "data": { - "action": "knowledge_search", - "query": user_message, - "done": False, + 'type': 'status', + 'data': { + 'action': 'knowledge_search', + 'query': user_message, + 'done': False, }, } ) knowledge_files = [] for item in model_knowledge: - if item.get("collection_name"): + if item.get('collection_name'): knowledge_files.append( { - "id": item.get("collection_name"), - "name": item.get("name"), - "legacy": True, + 'id': item.get('collection_name'), + 'name': item.get('name'), + 'legacy': True, } ) - elif item.get("collection_names"): + elif item.get('collection_names'): knowledge_files.append( { - "name": item.get("name"), - "type": "collection", - "collection_names": item.get("collection_names"), - "legacy": True, + 'name': item.get('name'), + 'type': 'collection', + 'collection_names': item.get('collection_names'), + 'legacy': True, } ) else: knowledge_files.append(item) - files = form_data.get("files", []) + files = form_data.get('files', []) files.extend(knowledge_files) - form_data["files"] = files + form_data['files'] = files - variables = form_data.pop("variables", None) + variables = form_data.pop('variables', None) # Process the form_data through the pipeline try: - form_data = await process_pipeline_inlet_filter( - request, form_data, user, models - ) + form_data = await process_pipeline_inlet_filter(request, form_data, user, models) except Exception as e: raise e try: - filter_ids = get_sorted_filter_ids( - request, model, metadata.get("filter_ids", []) - ) + filter_ids = get_sorted_filter_ids(request, model, metadata.get('filter_ids', [])) filter_functions = Functions.get_functions_by_ids(filter_ids) form_data, flags = await process_filter_functions( request=request, filter_functions=filter_functions, - filter_type="inlet", + filter_type='inlet', form_data=form_data, extra_params=extra_params, ) except Exception as e: - raise Exception(f"{e}") + raise Exception(f'{e}') - features = form_data.pop("features", None) or {} - extra_params["__features__"] = features + features = form_data.pop('features', None) or {} + extra_params['__features__'] = features if features: - if "voice" in features and features["voice"]: + if 'voice' in features and features['voice']: if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != None: - if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != "": + if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != '': template = request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE else: template = DEFAULT_VOICE_MODE_PROMPT_TEMPLATE - form_data["messages"] = add_or_update_system_message( + form_data['messages'] = add_or_update_system_message( template, - form_data["messages"], + form_data['messages'], ) - if "memory" in features and features["memory"]: + if 'memory' in features and features['memory']: # Skip forced memory injection when native FC is enabled - model can use memory tools - if metadata.get("params", {}).get("function_calling") != "native": - form_data = await chat_memory_handler( - request, form_data, extra_params, user - ) + if metadata.get('params', {}).get('function_calling') != 'native': + form_data = await chat_memory_handler(request, form_data, extra_params, user) - if "web_search" in features and features["web_search"]: + if 'web_search' in features and features['web_search']: # Skip forced RAG web search when native FC is enabled - model can use web_search tool - if metadata.get("params", {}).get("function_calling") != "native": - form_data = await chat_web_search_handler( - request, form_data, extra_params, user - ) + if metadata.get('params', {}).get('function_calling') != 'native': + form_data = await chat_web_search_handler(request, form_data, extra_params, user) - if "image_generation" in features and features["image_generation"]: + if 'image_generation' in features and features['image_generation']: # Skip forced image generation when native FC is enabled - model can use generate_image tool - if metadata.get("params", {}).get("function_calling") != "native": - form_data = await chat_image_generation_handler( - request, form_data, extra_params, user - ) + if metadata.get('params', {}).get('function_calling') != 'native': + form_data = await chat_image_generation_handler(request, form_data, extra_params, user) - if "code_interpreter" in features and features["code_interpreter"]: - engine = getattr( - request.app.state.config, "CODE_INTERPRETER_ENGINE", "pyodide" - ) + if 'code_interpreter' in features and features['code_interpreter']: + engine = getattr(request.app.state.config, 'CODE_INTERPRETER_ENGINE', 'pyodide') # Skip XML-tag prompt injection when native FC is enabled — # execute_code will be injected as a builtin tool instead - if metadata.get("params", {}).get("function_calling") != "native": + if metadata.get('params', {}).get('function_calling') != 'native': prompt = ( request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE - if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != "" + if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != '' else DEFAULT_CODE_INTERPRETER_PROMPT ) # Append filesystem awareness only for pyodide engine - if engine != "jupyter": + if engine != 'jupyter': prompt += CODE_INTERPRETER_PYODIDE_PROMPT - form_data["messages"] = add_or_update_user_message( + form_data['messages'] = add_or_update_user_message( prompt, - form_data["messages"], + form_data['messages'], ) else: # Native FC: tool docstring can't be dynamic, so inject # filesystem context into messages for pyodide engine - if engine != "jupyter": - form_data["messages"] = add_or_update_user_message( + if engine != 'jupyter': + form_data['messages'] = add_or_update_user_message( CODE_INTERPRETER_PYODIDE_PROMPT, - form_data["messages"], + form_data['messages'], ) - tool_ids = form_data.pop("tool_ids", None) - terminal_id = form_data.pop("terminal_id", None) - files = form_data.pop("files", None) + tool_ids = form_data.pop('tool_ids', None) + terminal_id = form_data.pop('terminal_id', None) + files = form_data.pop('files', None) # Caller-provided OpenAI-style tools take precedence over server-side # tool resolution (tool_ids, MCP servers, builtin tools). - payload_tools = form_data.get("tools", None) + payload_tools = form_data.get('tools', None) # Skills - user_skill_ids = set(form_data.pop("skill_ids", None) or []) - model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", [])) + user_skill_ids = set(form_data.pop('skill_ids', None) or []) + model_skill_ids = set(model.get('info', {}).get('meta', {}).get('skillIds', [])) all_skill_ids = user_skill_ids | model_skill_ids available_skills = [] if all_skill_ids: from open_webui.models.skills import Skills as SkillsModel - accessible_skill_ids = { - s.id for s in SkillsModel.get_skills_by_user_id(user.id, "read") - } + accessible_skill_ids = {s.id for s in SkillsModel.get_skills_by_user_id(user.id, 'read')} available_skills = [ s for sid in all_skill_ids - if sid in accessible_skill_ids - and (s := SkillsModel.get_skill_by_id(sid)) - and s.is_active + if sid in accessible_skill_ids and (s := SkillsModel.get_skill_by_id(sid)) and s.is_active ] - skill_descriptions = "" + skill_descriptions = '' for skill in available_skills: if skill.id in user_skill_ids: # User-selected: inject full content - form_data["messages"] = add_or_update_system_message( + form_data['messages'] = add_or_update_system_message( f'\n{skill.content}\n', - form_data["messages"], + form_data['messages'], append=True, ) else: # Model-attached: name+description only - skill_descriptions += f"\n{skill.name}\n{skill.description or ''}\n\n" + skill_descriptions += f'\n{skill.name}\n{skill.description or ""}\n\n' if skill_descriptions: - form_data["messages"] = add_or_update_system_message( - f"\n{skill_descriptions}", - form_data["messages"], + form_data['messages'] = add_or_update_system_message( + f'\n{skill_descriptions}', + form_data['messages'], append=True, ) - prompt = get_last_user_message(form_data["messages"]) + prompt = get_last_user_message(form_data['messages']) # TODO: re-enable URL extraction from prompt # urls = [] # if prompt and len(prompt or "") < 500 and (not files or len(files) == 0): @@ -2456,14 +2343,14 @@ async def process_chat_payload(request, form_data, user, metadata, model): files = [] for file_item in files: - if file_item.get("type", "file") == "folder": + if file_item.get('type', 'file') == 'folder': # Get folder files - folder_id = file_item.get("id", None) + folder_id = file_item.get('id', None) if folder_id: folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id) - if folder and folder.data and "files" in folder.data: - files = [f for f in files if f.get("id", None) != folder_id] - files = [*files, *folder.data["files"]] + if folder and folder.data and 'files' in folder.data: + files = [f for f in files if f.get('id', None) != folder_id] + files = [*files, *folder.data['files']] # files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]] # Remove duplicate files based on their content @@ -2471,23 +2358,23 @@ async def process_chat_payload(request, form_data, user, metadata, model): metadata = { **metadata, - "tool_ids": tool_ids, - "terminal_id": terminal_id, - "files": files, + 'tool_ids': tool_ids, + 'terminal_id': terminal_id, + 'files': files, } - form_data["metadata"] = metadata + form_data['metadata'] = metadata # When the caller provides an explicit OpenAI-style `tools` array in the # request body, skip all server-side tool resolution and pass the caller's # tools through to the model unchanged. if not payload_tools: # Server side tools - tool_ids = metadata.get("tool_ids", None) + tool_ids = metadata.get('tool_ids', None) # Client side tools - direct_tool_servers = metadata.get("tool_servers", None) + direct_tool_servers = metadata.get('tool_servers', None) - log.debug(f"{tool_ids=}") - log.debug(f"{direct_tool_servers=}") + log.debug(f'{tool_ids=}') + log.debug(f'{direct_tool_servers=}') tools_dict = {} @@ -2496,70 +2383,57 @@ async def process_chat_payload(request, form_data, user, metadata, model): if tool_ids: for tool_id in tool_ids: - if tool_id.startswith("server:mcp:"): + if tool_id.startswith('server:mcp:'): try: - server_id = tool_id[len("server:mcp:") :] + server_id = tool_id[len('server:mcp:') :] mcp_server_connection = None - for ( - server_connection - ) in request.app.state.config.TOOL_SERVER_CONNECTIONS: + for server_connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: if ( - server_connection.get("type", "") == "mcp" - and server_connection.get("info", {}).get("id") - == server_id + server_connection.get('type', '') == 'mcp' + and server_connection.get('info', {}).get('id') == server_id ): mcp_server_connection = server_connection break if not mcp_server_connection: - log.error(f"MCP server with id {server_id} not found") + log.error(f'MCP server with id {server_id} not found') continue # Check access control for MCP server if not has_connection_access(user, mcp_server_connection): - log.warning( - f"Access denied to MCP server {server_id} for user {user.id}" - ) + log.warning(f'Access denied to MCP server {server_id} for user {user.id}') continue - auth_type = mcp_server_connection.get("auth_type", "") + auth_type = mcp_server_connection.get('auth_type', '') headers = {} - if auth_type == "bearer": - headers["Authorization"] = ( - f"Bearer {mcp_server_connection.get('key', '')}" - ) - elif auth_type == "none": + if auth_type == 'bearer': + headers['Authorization'] = f'Bearer {mcp_server_connection.get("key", "")}' + elif auth_type == 'none': # No authentication pass - elif auth_type == "session": - headers["Authorization"] = ( - f"Bearer {request.state.token.credentials}" - ) - elif auth_type == "system_oauth": - oauth_token = extra_params.get("__oauth_token__", None) + elif auth_type == 'session': + headers['Authorization'] = f'Bearer {request.state.token.credentials}' + elif auth_type == 'system_oauth': + oauth_token = extra_params.get('__oauth_token__', None) if oauth_token: - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) - elif auth_type == "oauth_2.1": + headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}' + elif auth_type == 'oauth_2.1': try: - splits = server_id.split(":") + splits = server_id.split(':') server_id = splits[-1] if len(splits) > 1 else server_id oauth_token = await request.app.state.oauth_client_manager.get_oauth_token( - user.id, f"mcp:{server_id}" + user.id, f'mcp:{server_id}' ) if oauth_token: - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) + headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}' except Exception as e: - log.error(f"Error getting OAuth token: {e}") + log.error(f'Error getting OAuth token: {e}') oauth_token = None - connection_headers = mcp_server_connection.get("headers", None) + connection_headers = mcp_server_connection.get('headers', None) if connection_headers and isinstance(connection_headers, dict): for key, value in connection_headers.items(): headers[key] = value @@ -2567,29 +2441,23 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Add user info headers if enabled if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - if metadata and metadata.get("chat_id"): - headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = ( - metadata.get("chat_id") - ) - if metadata and metadata.get("message_id"): - headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = ( - metadata.get("message_id") - ) + if metadata and metadata.get('chat_id'): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id') + if metadata and metadata.get('message_id'): + headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = metadata.get('message_id') mcp_clients[server_id] = MCPClient() await mcp_clients[server_id].connect( - url=mcp_server_connection.get("url", ""), + url=mcp_server_connection.get('url', ''), headers=headers if headers else None, ) - function_name_filter_list = mcp_server_connection.get( - "config", {} - ).get("function_name_filter_list", "") + function_name_filter_list = mcp_server_connection.get('config', {}).get( + 'function_name_filter_list', '' + ) if isinstance(function_name_filter_list, str): - function_name_filter_list = function_name_filter_list.split( - "," - ) + function_name_filter_list = function_name_filter_list.split(',') tool_specs = await mcp_clients[server_id].list_tool_specs() for tool_spec in tool_specs: @@ -2604,37 +2472,29 @@ async def process_chat_payload(request, form_data, user, metadata, model): return tool_function if function_name_filter_list: - if not is_string_allowed( - tool_spec["name"], function_name_filter_list - ): + if not is_string_allowed(tool_spec['name'], function_name_filter_list): # Skip this function continue - tool_function = make_tool_function( - mcp_clients[server_id], tool_spec["name"] - ) + tool_function = make_tool_function(mcp_clients[server_id], tool_spec['name']) - mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = { - "spec": { + mcp_tools_dict[f'{server_id}_{tool_spec["name"]}'] = { + 'spec': { **tool_spec, - "name": f"{server_id}_{tool_spec['name']}", + 'name': f'{server_id}_{tool_spec["name"]}', }, - "callable": tool_function, - "type": "mcp", - "client": mcp_clients[server_id], - "direct": False, + 'callable': tool_function, + 'type': 'mcp', + 'client': mcp_clients[server_id], + 'direct': False, } except Exception as e: log.debug(e) if event_emitter: await event_emitter( { - "type": "chat:message:error", - "data": { - "error": { - "content": f"Failed to connect to MCP server '{server_id}'" - } - }, + 'type': 'chat:message:error', + 'data': {'error': {'content': f"Failed to connect to MCP server '{server_id}'"}}, } ) continue @@ -2645,9 +2505,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): user, { **extra_params, - "__model__": models[task_model_id], - "__messages__": form_data["messages"], - "__files__": metadata.get("files", []), + '__model__': models[task_model_id], + '__messages__': form_data['messages'], + '__files__': metadata.get('files', []), }, ) @@ -2671,40 +2531,33 @@ async def process_chat_payload(request, form_data, user, metadata, model): if direct_tool_servers: for tool_server in direct_tool_servers: - tool_specs = tool_server.pop("specs", []) + tool_specs = tool_server.pop('specs', []) for tool in tool_specs: - tools_dict[tool["name"]] = { - "spec": tool, - "direct": True, - "server": tool_server, + tools_dict[tool['name']] = { + 'spec': tool, + 'direct': True, + 'server': tool_server, } if mcp_clients: - metadata["mcp_clients"] = mcp_clients + metadata['mcp_clients'] = mcp_clients # Inject builtin tools for native function calling based on enabled features and model capability # Check if builtin_tools capability is enabled for this model (defaults to True if not specified) - builtin_tools_enabled = ( - model.get("info", {}).get("meta", {}).get("capabilities") or {} - ).get("builtin_tools", True) - if ( - metadata.get("params", {}).get("function_calling") == "native" - and builtin_tools_enabled - ): + builtin_tools_enabled = (model.get('info', {}).get('meta', {}).get('capabilities') or {}).get( + 'builtin_tools', True + ) + if metadata.get('params', {}).get('function_calling') == 'native' and builtin_tools_enabled: # Add file context to user messages - chat_id = metadata.get("chat_id") - form_data["messages"] = add_file_context( - form_data.get("messages", []), chat_id, user - ) + chat_id = metadata.get('chat_id') + form_data['messages'] = add_file_context(form_data.get('messages', []), chat_id, user) builtin_tools = get_builtin_tools( request, { **extra_params, - "__event_emitter__": event_emitter, - "__skill_ids__": [ - s.id for s in available_skills if s.id not in user_skill_ids - ], + '__event_emitter__': event_emitter, + '__skill_ids__': [s.id for s in available_skills if s.id not in user_skill_ids], }, features, model, @@ -2714,12 +2567,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): tools_dict[name] = tool_dict if tools_dict: - if metadata.get("params", {}).get("function_calling") == "native": + if metadata.get('params', {}).get('function_calling') == 'native': # If the function calling is native, then call the tools function calling handler - metadata["tools"] = tools_dict - form_data["tools"] = [ - {"type": "function", "function": tool.get("spec", {})} - for tool in tools_dict.values() + metadata['tools'] = tools_dict + form_data['tools'] = [ + {'type': 'function', 'function': tool.get('spec', {})} for tool in tools_dict.values() ] else: # If the function calling is not native, then call the tools function calling handler @@ -2727,60 +2579,51 @@ async def process_chat_payload(request, form_data, user, metadata, model): form_data, flags = await chat_completion_tools_handler( request, form_data, extra_params, user, models, tools_dict ) - sources.extend(flags.get("sources", [])) + sources.extend(flags.get('sources', [])) except Exception as e: log.exception(e) # Check if file context extraction is enabled for this model (default True) - file_context_enabled = ( - model.get("info", {}).get("meta", {}).get("capabilities") or {} - ).get("file_context", True) + file_context_enabled = (model.get('info', {}).get('meta', {}).get('capabilities') or {}).get('file_context', True) if file_context_enabled: try: - form_data, flags = await chat_completion_files_handler( - request, form_data, extra_params, user - ) - sources.extend(flags.get("sources", [])) + form_data, flags = await chat_completion_files_handler(request, form_data, extra_params, user) + sources.extend(flags.get('sources', [])) except Exception as e: log.exception(e) # Save the pre-RAG message state so the native tool call loop can # restore to the true original (before file-source injection) rather # than a snapshot that already has the RAG template baked in. - system_message = get_system_message(form_data["messages"]) - metadata["system_prompt"] = ( - get_content_from_message(system_message) if system_message else None - ) - metadata["user_prompt"] = get_last_user_message(form_data["messages"]) - metadata["sources"] = sources[:] if sources else [] + system_message = get_system_message(form_data['messages']) + metadata['system_prompt'] = get_content_from_message(system_message) if system_message else None + metadata['user_prompt'] = get_last_user_message(form_data['messages']) + metadata['sources'] = sources[:] if sources else [] # If context is not empty, insert it into the messages if sources and prompt: - form_data["messages"] = apply_source_context_to_messages( - request, form_data["messages"], sources, prompt - ) + form_data['messages'] = apply_source_context_to_messages(request, form_data['messages'], sources, prompt) # If there are citations, add them to the data_items sources = [ source for source in sources - if source.get("source", {}).get("name", "") - or source.get("source", {}).get("id", "") + if source.get('source', {}).get('name', '') or source.get('source', {}).get('id', '') ] if len(sources) > 0: - events.append({"sources": sources}) + events.append({'sources': sources}) if model_knowledge: await event_emitter( { - "type": "status", - "data": { - "action": "knowledge_search", - "query": user_message, - "done": True, - "hidden": True, + 'type': 'status', + 'data': { + 'action': 'knowledge_search', + 'query': user_message, + 'done': True, + 'hidden': True, }, } ) @@ -2792,32 +2635,30 @@ def get_event_emitter_and_caller(metadata): event_emitter = None event_caller = None if ( - "session_id" in metadata - and metadata["session_id"] - and "chat_id" in metadata - and metadata["chat_id"] - and "message_id" in metadata - and metadata["message_id"] + 'session_id' in metadata + and metadata['session_id'] + and 'chat_id' in metadata + and metadata['chat_id'] + and 'message_id' in metadata + and metadata['message_id'] ): event_emitter = get_event_emitter(metadata) event_caller = get_event_call(metadata) return event_emitter, event_caller -def build_chat_response_context( - request, form_data, user, model, metadata, tasks, events -): +def build_chat_response_context(request, form_data, user, model, metadata, tasks, events): event_emitter, event_caller = get_event_emitter_and_caller(metadata) return { - "request": request, - "form_data": form_data, - "user": user, - "model": model, - "metadata": metadata, - "tasks": tasks, - "events": events, - "event_emitter": event_emitter, - "event_caller": event_caller, + 'request': request, + 'form_data': form_data, + 'user': user, + 'model': model, + 'metadata': metadata, + 'tasks': tasks, + 'events': events, + 'event_emitter': event_emitter, + 'event_caller': event_caller, } @@ -2829,9 +2670,9 @@ def get_response_data(response): if isinstance(response, JSONResponse): if isinstance(response.body, bytes): try: - response_data = json.loads(response.body.decode("utf-8", "replace")) + response_data = json.loads(response.body.decode('utf-8', 'replace')) except json.JSONDecodeError: - response_data = {"error": {"detail": "Invalid JSON response"}} + response_data = {'error': {'detail': 'Invalid JSON response'}} else: response_data = response elif isinstance(response, dict): @@ -2873,32 +2714,32 @@ def build_response_object(response, response_data): async def get_system_oauth_token(request, user): oauth_token = None try: - if request.cookies.get("oauth_session_id", None): + if request.cookies.get('oauth_session_id', None): oauth_token = await request.app.state.oauth_manager.get_oauth_token( user.id, - request.cookies.get("oauth_session_id", None), + request.cookies.get('oauth_session_id', None), ) except Exception as e: - log.error(f"Error getting OAuth token: {e}") + log.error(f'Error getting OAuth token: {e}') return oauth_token async def background_tasks_handler(ctx): - request = ctx["request"] - form_data = ctx["form_data"] - user = ctx["user"] - metadata = ctx["metadata"] - tasks = ctx["tasks"] - event_emitter = ctx["event_emitter"] + request = ctx['request'] + form_data = ctx['form_data'] + user = ctx['user'] + metadata = ctx['metadata'] + tasks = ctx['tasks'] + event_emitter = ctx['event_emitter'] message = None messages = [] - if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"): - messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"]) - message = messages_map.get(metadata["message_id"]) if messages_map else None + if 'chat_id' in metadata and not metadata['chat_id'].startswith('local:'): + messages_map = Chats.get_messages_map_by_chat_id(metadata['chat_id']) + message = messages_map.get(metadata['message_id']) if messages_map else None - message_list = get_message_list(messages_map, metadata["message_id"]) + message_list = get_message_list(messages_map, metadata['message_id']) # Remove details tags and files from the messages. # as get_message_list creates a new list, it does not affect @@ -2906,17 +2747,17 @@ async def background_tasks_handler(ctx): messages = [] for message in message_list: - content = message.get("content", "") + content = message.get('content', '') if isinstance(content, list): for item in content: - if item.get("type") == "text": - content = item["text"] + if item.get('type') == 'text': + content = item['text'] break if isinstance(content, str): content = re.sub( - r"]*>.*?<\/details>|!\[.*?\]\(.*?\)", - "", + r']*>.*?<\/details>|!\[.*?\]\(.*?\)', + '', content, flags=re.S | re.I, ).strip() @@ -2924,141 +2765,128 @@ async def background_tasks_handler(ctx): messages.append( { **message, - "role": message.get( - "role", "assistant" - ), # Safe fallback for missing role - "content": content, + 'role': message.get('role', 'assistant'), # Safe fallback for missing role + 'content': content, } ) else: # Local temp chat, get the model and message from the form_data - message = get_last_user_message_item(form_data.get("messages", [])) - messages = form_data.get("messages", []) + message = get_last_user_message_item(form_data.get('messages', [])) + messages = form_data.get('messages', []) if message: - message["model"] = form_data.get("model") + message['model'] = form_data.get('model') - if message and "model" in message: + if message and 'model' in message: if tasks and messages: - if ( - TASKS.FOLLOW_UP_GENERATION in tasks - and tasks[TASKS.FOLLOW_UP_GENERATION] - ): + if TASKS.FOLLOW_UP_GENERATION in tasks and tasks[TASKS.FOLLOW_UP_GENERATION]: res = await generate_follow_ups( request, { - "model": message["model"], - "messages": messages, - "message_id": metadata["message_id"], - "chat_id": metadata["chat_id"], + 'model': message['model'], + 'messages': messages, + 'message_id': metadata['message_id'], + 'chat_id': metadata['chat_id'], }, user, ) if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - response_message = res.get("choices", [])[0].get("message", {}) + if len(res.get('choices', [])) == 1: + response_message = res.get('choices', [])[0].get('message', {}) - follow_ups_string = response_message.get( - "content" - ) or response_message.get("reasoning_content", "") + follow_ups_string = response_message.get('content') or response_message.get( + 'reasoning_content', '' + ) else: - follow_ups_string = "" + follow_ups_string = '' follow_ups_string = follow_ups_string[ - follow_ups_string.find("{") : follow_ups_string.rfind("}") + 1 + follow_ups_string.find('{') : follow_ups_string.rfind('}') + 1 ] try: - follow_ups = json.loads(follow_ups_string).get("follow_ups", []) + follow_ups = json.loads(follow_ups_string).get('follow_ups', []) await event_emitter( { - "type": "chat:message:follow_ups", - "data": { - "follow_ups": follow_ups, + 'type': 'chat:message:follow_ups', + 'data': { + 'follow_ups': follow_ups, }, } ) - if not metadata.get("chat_id", "").startswith("local:"): + if not metadata.get('chat_id', '').startswith('local:'): Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "followUps": follow_ups, + 'followUps': follow_ups, }, ) except Exception as e: pass - if not metadata.get("chat_id", "").startswith( - "local:" - ): # Only update titles and tags for non-temp chats + if not metadata.get('chat_id', '').startswith('local:'): # Only update titles and tags for non-temp chats if TASKS.TITLE_GENERATION in tasks: user_message = get_last_user_message(messages) if user_message and len(user_message) > 100: - user_message = user_message[:100] + "..." + user_message = user_message[:100] + '...' title = None if tasks[TASKS.TITLE_GENERATION]: res = await generate_title( request, { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], + 'model': message['model'], + 'messages': messages, + 'chat_id': metadata['chat_id'], }, user, ) if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - response_message = res.get("choices", [])[0].get( - "message", {} - ) + if len(res.get('choices', [])) == 1: + response_message = res.get('choices', [])[0].get('message', {}) title_string = ( - response_message.get("content") + response_message.get('content') or response_message.get( - "reasoning_content", + 'reasoning_content', ) - or message.get("content", user_message) + or message.get('content', user_message) ) else: - title_string = "" + title_string = '' - title_string = title_string[ - title_string.find("{") : title_string.rfind("}") + 1 - ] + title_string = title_string[title_string.find('{') : title_string.rfind('}') + 1] try: - title = json.loads(title_string).get( - "title", user_message - ) + title = json.loads(title_string).get('title', user_message) except Exception as e: - title = "" + title = '' if not title: - title = messages[0].get("content", user_message) + title = messages[0].get('content', user_message) - Chats.update_chat_title_by_id(metadata["chat_id"], title) + Chats.update_chat_title_by_id(metadata['chat_id'], title) await event_emitter( { - "type": "chat:title", - "data": title, + 'type': 'chat:title', + 'data': title, } ) if title == None and len(messages) == 2: - title = messages[0].get("content", user_message) + title = messages[0].get('content', user_message) - Chats.update_chat_title_by_id(metadata["chat_id"], title) + Chats.update_chat_title_by_id(metadata['chat_id'], title) await event_emitter( { - "type": "chat:title", - "data": message.get("content", user_message), + 'type': 'chat:title', + 'data': message.get('content', user_message), } ) @@ -3066,39 +2894,33 @@ async def background_tasks_handler(ctx): res = await generate_chat_tags( request, { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], + 'model': message['model'], + 'messages': messages, + 'chat_id': metadata['chat_id'], }, user, ) if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - response_message = res.get("choices", [])[0].get( - "message", {} + if len(res.get('choices', [])) == 1: + response_message = res.get('choices', [])[0].get('message', {}) + + tags_string = response_message.get('content') or response_message.get( + 'reasoning_content', '' ) - - tags_string = response_message.get( - "content" - ) or response_message.get("reasoning_content", "") else: - tags_string = "" + tags_string = '' - tags_string = tags_string[ - tags_string.find("{") : tags_string.rfind("}") + 1 - ] + tags_string = tags_string[tags_string.find('{') : tags_string.rfind('}') + 1] try: - tags = json.loads(tags_string).get("tags", []) - Chats.update_chat_tags_by_id( - metadata["chat_id"], tags, user - ) + tags = json.loads(tags_string).get('tags', []) + Chats.update_chat_tags_by_id(metadata['chat_id'], tags, user) await event_emitter( { - "type": "chat:tags", - "data": tags, + 'type': 'chat:tags', + 'data': tags, } ) except Exception as e: @@ -3106,13 +2928,13 @@ async def background_tasks_handler(ctx): async def non_streaming_chat_response_handler(response, ctx): - request = ctx["request"] + request = ctx['request'] - user = ctx["user"] - metadata = ctx["metadata"] - events = ctx["events"] + user = ctx['user'] + metadata = ctx['metadata'] + events = ctx['events'] - event_emitter = ctx["event_emitter"] + event_emitter = ctx['event_emitter'] response, response_data = get_response_data(response) if response_data is None: @@ -3120,89 +2942,89 @@ async def non_streaming_chat_response_handler(response, ctx): if event_emitter: try: - if "error" in response_data: - error = response_data.get("error") + if 'error' in response_data: + error = response_data.get('error') if isinstance(error, dict): - error = error.get("detail", error) + error = error.get('detail', error) else: error = str(error) Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "error": {"content": error}, + 'error': {'content': error}, }, ) if isinstance(error, str) or isinstance(error, dict): await event_emitter( { - "type": "chat:message:error", - "data": {"error": {"content": error}}, + 'type': 'chat:message:error', + 'data': {'error': {'content': error}}, } ) - if "selected_model_id" in response_data: + if 'selected_model_id' in response_data: Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "selectedModelId": response_data["selected_model_id"], + 'selectedModelId': response_data['selected_model_id'], }, ) - choices = response_data.get("choices", []) - if choices and choices[0].get("message", {}).get("content"): - content = response_data["choices"][0]["message"]["content"] + choices = response_data.get('choices', []) + if choices and choices[0].get('message', {}).get('content'): + content = response_data['choices'][0]['message']['content'] if content: await event_emitter( { - "type": "chat:completion", - "data": response_data, + 'type': 'chat:completion', + 'data': response_data, } ) - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + title = Chats.get_chat_title_by_id(metadata['chat_id']) # Use output from backend if provided (OR-compliant backends), # otherwise generate from response content - response_output = response_data.get("output") + response_output = response_data.get('output') if not response_output: response_output = [ { - "type": "message", - "id": output_id("msg"), - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": content}], + 'type': 'message', + 'id': output_id('msg'), + 'status': 'completed', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': content}], } ] await event_emitter( { - "type": "chat:completion", - "data": { - "done": True, - "content": content, - "output": response_output, - "title": title, + 'type': 'chat:completion', + 'data': { + 'done': True, + 'content': content, + 'output': response_output, + 'title': title, }, } ) # Save message in the database - usage = normalize_usage(response_data.get("usage", {}) or {}) + usage = normalize_usage(response_data.get('usage', {}) or {}) Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "role": "assistant", - "content": content, - "output": response_output, - **({"usage": usage} if usage else {}), + 'role': 'assistant', + 'content': content, + 'output': response_output, + **({'usage': usage} if usage else {}), }, ) @@ -3213,22 +3035,20 @@ async def non_streaming_chat_response_handler(response, ctx): await post_webhook( request.app.state.WEBUI_NAME, webhook_url, - f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", + f'{title} - {request.app.state.config.WEBUI_URL}/c/{metadata["chat_id"]}\n\n{content}', { - "action": "chat", - "message": content, - "title": title, - "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", + 'action': 'chat', + 'message': content, + 'title': title, + 'url': f'{request.app.state.config.WEBUI_URL}/c/{metadata["chat_id"]}', }, ) await background_tasks_handler(ctx) - response = build_response_object( - response, merge_events_into_response(response_data, events) - ) + response = build_response_object(response, merge_events_into_response(response_data, events)) except Exception as e: - log.debug(f"Error occurred while processing request: {e}") + log.debug(f'Error occurred while processing request: {e}') pass return response @@ -3240,40 +3060,38 @@ async def non_streaming_chat_response_handler(response, ctx): async def streaming_chat_response_handler(response, ctx): - request = ctx["request"] + request = ctx['request'] - form_data = ctx["form_data"] + form_data = ctx['form_data'] - user = ctx["user"] - model = ctx["model"] + user = ctx['user'] + model = ctx['model'] - metadata = ctx["metadata"] - events = ctx["events"] + metadata = ctx['metadata'] + events = ctx['events'] - event_emitter = ctx["event_emitter"] - event_caller = ctx["event_caller"] + event_emitter = ctx['event_emitter'] + event_caller = ctx['event_caller'] extra_params = { - "__event_emitter__": event_emitter, - "__event_call__": event_caller, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__oauth_token__": await get_system_oauth_token(request, user), - "__request__": request, - "__model__": model, + '__event_emitter__': event_emitter, + '__event_call__': event_caller, + '__user__': user.model_dump() if isinstance(user, UserModel) else {}, + '__metadata__': metadata, + '__oauth_token__': await get_system_oauth_token(request, user), + '__request__': request, + '__model__': model, } filter_functions = [ Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( - request, model, metadata.get("filter_ids", []) - ) + for filter_id in get_sorted_filter_ids(request, model, metadata.get('filter_ids', [])) ] # Standard streaming response handler if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. - model_id = form_data.get("model", "") + model_id = form_data.get('model', '') # Handle as a background task async def response_handler(response, events): @@ -3300,46 +3118,43 @@ async def streaming_chat_response_handler(response, ctx): def get_last_text(out): """Get text from last message item, or empty string.""" - if out and out[-1].get("type") == "message": - parts = out[-1].get("content", []) - if parts and parts[-1].get("type") == "output_text": - return parts[-1].get("text", "") - return "" + if out and out[-1].get('type') == 'message': + parts = out[-1].get('content', []) + if parts and parts[-1].get('type') == 'output_text': + return parts[-1].get('text', '') + return '' def set_last_text(out, text): """Set text on last message item's output_text.""" - if out and out[-1].get("type") == "message": - parts = out[-1].get("content", []) - if parts and parts[-1].get("type") == "output_text": - parts[-1]["text"] = text + if out and out[-1].get('type') == 'message': + parts = out[-1].get('content', []) + if parts and parts[-1].get('type') == 'output_text': + parts[-1]['text'] = text # Map content_type to output item type output_type_map = { - "reasoning": "reasoning", - "solution": "message", # solution tags just produce text - "code_interpreter": "open_webui:code_interpreter", + 'reasoning': 'reasoning', + 'solution': 'message', # solution tags just produce text + 'code_interpreter': 'open_webui:code_interpreter', } output_item_type = output_type_map.get(content_type, content_type) - last_type = output[-1].get("type", "") if output else "" + last_type = output[-1].get('type', '') if output else '' - if last_type == "message": + if last_type == 'message': # Use the output item's own text for tag detection item_text = get_last_text(output) for start_tag, end_tag in tags: - - start_tag_pattern = rf"{re.escape(start_tag)}" - if start_tag.startswith("<") and start_tag.endswith(">"): - start_tag_pattern = ( - rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" - ) + start_tag_pattern = rf'{re.escape(start_tag)}' + if start_tag.startswith('<') and start_tag.endswith('>'): + start_tag_pattern = rf'<{re.escape(start_tag[1:-1])}(\s.*?)?>' match = re.search(start_tag_pattern, item_text) if match: try: - attr_content = match.group(1) if match.group(1) else "" + attr_content = match.group(1) if match.group(1) else '' except Exception: - attr_content = "" + attr_content = '' attributes = extract_attributes(attr_content) @@ -3351,102 +3166,90 @@ async def streaming_chat_response_handler(response, ctx): if not before_tag.strip(): # Remove empty message item - if output and output[-1].get("type") == "message": + if output and output[-1].get('type') == 'message': output.pop() # Append the new output item - if output_item_type == "reasoning": + if output_item_type == 'reasoning': output.append( { - "type": "reasoning", - "id": output_id("r"), - "status": "in_progress", - "start_tag": start_tag, - "end_tag": end_tag, - "attributes": attributes, - "content": [], - "summary": None, - "started_at": time.time(), + 'type': 'reasoning', + 'id': output_id('r'), + 'status': 'in_progress', + 'start_tag': start_tag, + 'end_tag': end_tag, + 'attributes': attributes, + 'content': [], + 'summary': None, + 'started_at': time.time(), } ) - elif output_item_type == "open_webui:code_interpreter": + elif output_item_type == 'open_webui:code_interpreter': output.append( { - "type": "open_webui:code_interpreter", - "id": output_id("ci"), - "status": "in_progress", - "start_tag": start_tag, - "end_tag": end_tag, - "attributes": attributes, - "lang": attributes.get("lang", "python"), - "code": "", - "output": None, - "started_at": time.time(), + 'type': 'open_webui:code_interpreter', + 'id': output_id('ci'), + 'status': 'in_progress', + 'start_tag': start_tag, + 'end_tag': end_tag, + 'attributes': attributes, + 'lang': attributes.get('lang', 'python'), + 'code': '', + 'output': None, + 'started_at': time.time(), } ) else: # solution or other text-producing tag output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ - {"type": "output_text", "text": ""} - ], - "_tag_type": content_type, - "start_tag": start_tag, - "end_tag": end_tag, - "attributes": attributes, - "started_at": time.time(), + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': ''}], + '_tag_type': content_type, + 'start_tag': start_tag, + 'end_tag': end_tag, + 'attributes': attributes, + 'started_at': time.time(), } ) if after_tag: # Set the after_tag content on the new item - if output_item_type == "reasoning": - output[-1]["content"] = [ - {"type": "output_text", "text": after_tag} - ] - elif output_item_type == "open_webui:code_interpreter": - output[-1]["code"] = after_tag + if output_item_type == 'reasoning': + output[-1]['content'] = [{'type': 'output_text', 'text': after_tag}] + elif output_item_type == 'open_webui:code_interpreter': + output[-1]['code'] = after_tag else: set_last_text(output, after_tag) - _, recursive_end = tag_output_handler( - content_type, tags, output - ) + _, recursive_end = tag_output_handler(content_type, tags, output) if recursive_end: end_flag = True break elif ( - (last_type == "reasoning" and content_type == "reasoning") - or ( - last_type == "open_webui:code_interpreter" - and content_type == "code_interpreter" - ) - or ( - last_type == "message" - and output[-1].get("_tag_type") == content_type - ) + (last_type == 'reasoning' and content_type == 'reasoning') + or (last_type == 'open_webui:code_interpreter' and content_type == 'code_interpreter') + or (last_type == 'message' and output[-1].get('_tag_type') == content_type) ): item = output[-1] - start_tag = item.get("start_tag", "") - end_tag = item.get("end_tag", "") + start_tag = item.get('start_tag', '') + end_tag = item.get('end_tag', '') - end_tag_pattern = rf"{re.escape(end_tag)}" + end_tag_pattern = rf'{re.escape(end_tag)}' # Get the block content from the item itself - if last_type == "reasoning": - parts = item.get("content", []) - block_content = "" - if parts and parts[-1].get("type") == "output_text": - block_content = parts[-1].get("text", "") - elif last_type == "open_webui:code_interpreter": - block_content = item.get("code", "") + if last_type == 'reasoning': + parts = item.get('content', []) + block_content = '' + if parts and parts[-1].get('type') == 'output_text': + block_content = parts[-1].get('text', '') + elif last_type == 'open_webui:code_interpreter': + block_content = item.get('code', '') else: block_content = get_last_text(output) @@ -3454,57 +3257,43 @@ async def streaming_chat_response_handler(response, ctx): end_flag = True # Strip start and end tags from content - start_tag_pattern = rf"{re.escape(start_tag)}" - if start_tag.startswith("<") and start_tag.endswith(">"): - start_tag_pattern = ( - rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" - ) - block_content = re.sub( - start_tag_pattern, "", block_content - ).strip() + start_tag_pattern = rf'{re.escape(start_tag)}' + if start_tag.startswith('<') and start_tag.endswith('>'): + start_tag_pattern = rf'<{re.escape(start_tag[1:-1])}(\s.*?)?>' + block_content = re.sub(start_tag_pattern, '', block_content).strip() end_tag_regex = re.compile(end_tag_pattern, re.DOTALL) split_content = end_tag_regex.split(block_content, maxsplit=1) - block_content = ( - split_content[0].strip() if split_content else "" - ) - leftover_content = ( - split_content[1].strip() if len(split_content) > 1 else "" - ) + block_content = split_content[0].strip() if split_content else '' + leftover_content = split_content[1].strip() if len(split_content) > 1 else '' if block_content: # Update the item with final content - if last_type == "reasoning": - item["content"] = [ - {"type": "output_text", "text": block_content} - ] - item["ended_at"] = time.time() - item["duration"] = int( - item["ended_at"] - item["started_at"] - ) - item["status"] = "completed" - elif last_type == "open_webui:code_interpreter": - item["code"] = block_content - item["ended_at"] = time.time() - item["duration"] = int( - item["ended_at"] - item["started_at"] - ) + if last_type == 'reasoning': + item['content'] = [{'type': 'output_text', 'text': block_content}] + item['ended_at'] = time.time() + item['duration'] = int(item['ended_at'] - item['started_at']) + item['status'] = 'completed' + elif last_type == 'open_webui:code_interpreter': + item['code'] = block_content + item['ended_at'] = time.time() + item['duration'] = int(item['ended_at'] - item['started_at']) else: set_last_text(output, block_content) - item["ended_at"] = time.time() + item['ended_at'] = time.time() # Reset by appending a new message item for leftover output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [ { - "type": "output_text", - "text": leftover_content, + 'type': 'output_text', + 'text': leftover_content, } ], } @@ -3514,14 +3303,14 @@ async def streaming_chat_response_handler(response, ctx): output.pop() output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [ { - "type": "output_text", - "text": leftover_content, + 'type': 'output_text', + 'text': leftover_content, } ], } @@ -3529,29 +3318,23 @@ async def streaming_chat_response_handler(response, ctx): return output, end_flag - message = Chats.get_message_by_id_and_message_id( - metadata["chat_id"], metadata["message_id"] - ) + message = Chats.get_message_by_id_and_message_id(metadata['chat_id'], metadata['message_id']) tool_calls = [] last_assistant_message = None try: - if form_data["messages"][-1]["role"] == "assistant": - last_assistant_message = get_last_assistant_message( - form_data["messages"] - ) + if form_data['messages'][-1]['role'] == 'assistant': + last_assistant_message = get_last_assistant_message(form_data['messages']) except Exception as e: pass content = ( - message.get("content", "") - if message - else last_assistant_message if last_assistant_message else "" + message.get('content', '') if message else last_assistant_message if last_assistant_message else '' ) # Initialize output: use existing from message if continuing, else create new - existing_output = message.get("output") if message else None + existing_output = message.get('output') if message else None if existing_output: output = existing_output else: @@ -3559,11 +3342,11 @@ async def streaming_chat_response_handler(response, ctx): if content: output = [ { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [{"type": "output_text", "text": content}], + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': content}], } ] else: @@ -3571,21 +3354,14 @@ async def streaming_chat_response_handler(response, ctx): usage = None - reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags") + reasoning_tags_param = metadata.get('params', {}).get('reasoning_tags') DETECT_REASONING_TAGS = reasoning_tags_param is not False - DETECT_CODE_INTERPRETER = metadata.get("features", {}).get( - "code_interpreter", False - ) + DETECT_CODE_INTERPRETER = metadata.get('features', {}).get('code_interpreter', False) reasoning_tags = [] if DETECT_REASONING_TAGS: - if ( - isinstance(reasoning_tags_param, list) - and len(reasoning_tags_param) == 2 - ): - reasoning_tags = [ - (reasoning_tags_param[0], reasoning_tags_param[1]) - ] + if isinstance(reasoning_tags_param, list) and len(reasoning_tags_param) == 2: + reasoning_tags = [(reasoning_tags_param[0], reasoning_tags_param[1])] else: reasoning_tags = DEFAULT_REASONING_TAGS @@ -3593,15 +3369,15 @@ async def streaming_chat_response_handler(response, ctx): for event in events: await event_emitter( { - "type": "chat:completion", - "data": event, + 'type': 'chat:completion', + 'data': event, } ) # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { **event, }, @@ -3617,10 +3393,7 @@ async def streaming_chat_response_handler(response, ctx): delta_count = 0 delta_chunk_size = max( CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE, - int( - metadata.get("params", {}).get("stream_delta_chunk_size") - or 1 - ), + int(metadata.get('params', {}).get('stream_delta_chunk_size') or 1), ) last_delta_data = None @@ -3631,19 +3404,15 @@ async def streaming_chat_response_handler(response, ctx): if delta_count >= threshold and last_delta_data: await event_emitter( { - "type": "chat:completion", - "data": last_delta_data, + 'type': 'chat:completion', + 'data': last_delta_data, } ) delta_count = 0 last_delta_data = None async for line in response.body_iterator: - line = ( - line.decode("utf-8", "replace") - if isinstance(line, bytes) - else line - ) + line = line.decode('utf-8', 'replace') if isinstance(line, bytes) else line data = line # Skip empty lines @@ -3651,11 +3420,11 @@ async def streaming_chat_response_handler(response, ctx): continue # "data:" is the prefix for each event - if not data.startswith("data:"): + if not data.startswith('data:'): continue # Remove the prefix - data = data[len("data:") :].strip() + data = data[len('data:') :].strip() try: data = json.loads(data) @@ -3663,41 +3432,37 @@ async def streaming_chat_response_handler(response, ctx): data, _ = await process_filter_functions( request=request, filter_functions=filter_functions, - filter_type="stream", + filter_type='stream', form_data=data, - extra_params={"__body__": form_data, **extra_params}, + extra_params={'__body__': form_data, **extra_params}, ) if data: - if "event" in data and not getattr( - request.state, "direct", False - ): - await event_emitter(data.get("event", {})) + if 'event' in data and not getattr(request.state, 'direct', False): + await event_emitter(data.get('event', {})) - if "selected_model_id" in data: - model_id = data["selected_model_id"] + if 'selected_model_id' in data: + model_id = data['selected_model_id'] Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "selectedModelId": model_id, + 'selectedModelId': model_id, }, ) await event_emitter( { - "type": "chat:completion", - "data": data, + 'type': 'chat:completion', + 'data': data, } ) # Check for Responses API events (type field starts with "response.") - elif data.get("type", "").startswith("response."): - output, response_metadata = ( - handle_responses_streaming_event(data, output) - ) + elif data.get('type', '').startswith('response.'): + output, response_metadata = handle_responses_streaming_event(data, output) processed_data = { - "output": output, - "content": serialize_output(output), + 'output': output, + 'content': serialize_output(output), } # print(data) @@ -3709,137 +3474,108 @@ async def streaming_chat_response_handler(response, ctx): await event_emitter( { - "type": "chat:completion", - "data": processed_data, + 'type': 'chat:completion', + 'data': processed_data, } ) continue else: - choices = data.get("choices", []) + choices = data.get('choices', []) # Normalize usage data to standard format - raw_usage = data.get("usage", {}) or {} - raw_usage.update( - data.get("timings", {}) - ) # llama.cpp + raw_usage = data.get('usage', {}) or {} + raw_usage.update(data.get('timings', {})) # llama.cpp if raw_usage: usage = normalize_usage(raw_usage) await event_emitter( { - "type": "chat:completion", - "data": { - "usage": usage, + 'type': 'chat:completion', + 'data': { + 'usage': usage, }, } ) if not choices: - error = data.get("error", {}) + error = data.get('error', {}) if error: await event_emitter( { - "type": "chat:completion", - "data": { - "error": error, + 'type': 'chat:completion', + 'data': { + 'error': error, }, } ) continue - delta = choices[0].get("delta", {}) + delta = choices[0].get('delta', {}) # Handle delta annotations - annotations = delta.get("annotations") + annotations = delta.get('annotations') if annotations: for annotation in annotations: if ( - annotation.get("type") == "url_citation" - and "url_citation" in annotation + annotation.get('type') == 'url_citation' + and 'url_citation' in annotation ): - url_citation = annotation[ - "url_citation" - ] + url_citation = annotation['url_citation'] - url = url_citation.get("url", "") - title = url_citation.get("title", url) + url = url_citation.get('url', '') + title = url_citation.get('title', url) await event_emitter( { - "type": "source", - "data": { - "source": { - "name": title, - "url": url, + 'type': 'source', + 'data': { + 'source': { + 'name': title, + 'url': url, }, - "document": [title], - "metadata": [ + 'document': [title], + 'metadata': [ { - "source": url, - "name": title, + 'source': url, + 'name': title, } ], }, } ) - delta_tool_calls = delta.get("tool_calls", None) + delta_tool_calls = delta.get('tool_calls', None) if delta_tool_calls: for delta_tool_call in delta_tool_calls: - tool_call_index = delta_tool_call.get( - "index" - ) + tool_call_index = delta_tool_call.get('index') if tool_call_index is not None: # Check if the tool call already exists current_response_tool_call = None - for ( - response_tool_call - ) in response_tool_calls: - if ( - response_tool_call.get("index") - == tool_call_index - ): - current_response_tool_call = ( - response_tool_call - ) + for response_tool_call in response_tool_calls: + if response_tool_call.get('index') == tool_call_index: + current_response_tool_call = response_tool_call break if current_response_tool_call is None: # Add the new tool call - delta_tool_call.setdefault( - "function", {} - ) - delta_tool_call[ - "function" - ].setdefault("name", "") - delta_tool_call[ - "function" - ].setdefault("arguments", "") - response_tool_calls.append( - delta_tool_call - ) + delta_tool_call.setdefault('function', {}) + delta_tool_call['function'].setdefault('name', '') + delta_tool_call['function'].setdefault('arguments', '') + response_tool_calls.append(delta_tool_call) else: # Update the existing tool call - delta_name = delta_tool_call.get( - "function", {} - ).get("name") - delta_arguments = ( - delta_tool_call.get( - "function", {} - ).get("arguments") + delta_name = delta_tool_call.get('function', {}).get('name') + delta_arguments = delta_tool_call.get('function', {}).get( + 'arguments' ) if delta_name: - current_response_tool_call[ - "function" - ]["name"] = delta_name + current_response_tool_call['function']['name'] = delta_name if delta_arguments: - current_response_tool_call[ - "function" - ][ - "arguments" - ] += delta_arguments + current_response_tool_call['function']['arguments'] += ( + delta_arguments + ) # Emit pending tool calls in real-time if response_tool_calls: @@ -3849,44 +3585,34 @@ async def streaming_chat_response_handler(response, ctx): # Build pending function_call output items for display pending_fc_items = [] for tc in response_tool_calls: - call_id = tc.get("id", "") - func = tc.get("function", {}) + call_id = tc.get('id', '') + func = tc.get('function', {}) pending_fc_items.append( { - "type": "function_call", - "id": call_id - or output_id("fc"), - "call_id": call_id, - "name": func.get("name", ""), - "arguments": func.get( - "arguments", "{}" - ), - "status": "in_progress", + 'type': 'function_call', + 'id': call_id or output_id('fc'), + 'call_id': call_id, + 'name': func.get('name', ''), + 'arguments': func.get('arguments', '{}'), + 'status': 'in_progress', } ) pending_output = output + pending_fc_items await event_emitter( { - "type": "chat:completion", - "data": { - "content": serialize_output( - pending_output - ), + 'type': 'chat:completion', + 'data': { + 'content': serialize_output(pending_output), }, } ) - image_urls = get_image_urls( - delta.get("images", []), request, metadata, user - ) + image_urls = get_image_urls(delta.get('images', []), request, metadata, user) if image_urls: - image_file_list = [ - {"type": "image", "url": url} - for url in image_urls - ] + image_file_list = [{'type': 'image', 'url': url} for url in image_urls] message_files = Chats.add_message_files_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], image_file_list, ) if message_files is None: @@ -3894,84 +3620,72 @@ async def streaming_chat_response_handler(response, ctx): await event_emitter( { - "type": "files", - "data": {"files": message_files}, + 'type': 'files', + 'data': {'files': message_files}, } ) - value = delta.get("content") + value = delta.get('content') reasoning_content = ( - delta.get("reasoning_content") - or delta.get("reasoning") - or delta.get("thinking") + delta.get('reasoning_content') + or delta.get('reasoning') + or delta.get('thinking') ) if reasoning_content: - if ( - not output - or output[-1].get("type") != "reasoning" - ): + if not output or output[-1].get('type') != 'reasoning': reasoning_item = { - "type": "reasoning", - "id": output_id("r"), - "status": "in_progress", - "start_tag": "", - "end_tag": "", - "attributes": { - "type": "reasoning_content" - }, - "content": [], - "summary": None, - "started_at": time.time(), + 'type': 'reasoning', + 'id': output_id('r'), + 'status': 'in_progress', + 'start_tag': '', + 'end_tag': '', + 'attributes': {'type': 'reasoning_content'}, + 'content': [], + 'summary': None, + 'started_at': time.time(), } output.append(reasoning_item) else: reasoning_item = output[-1] # Append to reasoning content - parts = reasoning_item.get("content", []) - if ( - parts - and parts[-1].get("type") == "output_text" - ): - parts[-1]["text"] += reasoning_content + parts = reasoning_item.get('content', []) + if parts and parts[-1].get('type') == 'output_text': + parts[-1]['text'] += reasoning_content else: - reasoning_item["content"] = [ + reasoning_item['content'] = [ { - "type": "output_text", - "text": reasoning_content, + 'type': 'output_text', + 'text': reasoning_content, } ] - data = {"content": serialize_output(output)} + data = {'content': serialize_output(output)} if value: if ( output - and output[-1].get("type") == "reasoning" - and output[-1] - .get("attributes", {}) - .get("type") - == "reasoning_content" + and output[-1].get('type') == 'reasoning' + and output[-1].get('attributes', {}).get('type') == 'reasoning_content' ): reasoning_item = output[-1] - reasoning_item["ended_at"] = time.time() - reasoning_item["duration"] = int( - reasoning_item["ended_at"] - - reasoning_item["started_at"] + reasoning_item['ended_at'] = time.time() + reasoning_item['duration'] = int( + reasoning_item['ended_at'] - reasoning_item['started_at'] ) - reasoning_item["status"] = "completed" + reasoning_item['status'] = 'completed' output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [ { - "type": "output_text", - "text": "", + 'type': 'output_text', + 'text': '', } ], } @@ -3982,17 +3696,13 @@ async def streaming_chat_response_handler(response, ctx): request, value, { - "chat_id": metadata.get( - "chat_id", None - ), - "message_id": metadata.get( - "message_id", None - ), + 'chat_id': metadata.get('chat_id', None), + 'message_id': metadata.get('message_id', None), }, user, ) - content = f"{content}{value}" + content = f'{content}{value}' # Check if we're inside a tag-based block # (reasoning, code_interpreter, or solution). @@ -4002,122 +3712,93 @@ async def streaming_chat_response_handler(response, ctx): # start tag on every chunk and fragments the # output. last_item = output[-1] if output else None - last_item_type = ( - last_item.get("type", "") - if last_item - else "" - ) + last_item_type = last_item.get('type', '') if last_item else '' inside_tag_block = ( last_item is not None - and last_item.get("status") == "in_progress" - and last_item.get("attributes", {}).get( - "type" - ) - != "reasoning_content" + and last_item.get('status') == 'in_progress' + and last_item.get('attributes', {}).get('type') != 'reasoning_content' and ( - last_item_type == "reasoning" - or last_item_type - == "open_webui:code_interpreter" + last_item_type == 'reasoning' + or last_item_type == 'open_webui:code_interpreter' or ( - last_item_type == "message" - and last_item.get("_tag_type") - is not None + last_item_type == 'message' + and last_item.get('_tag_type') is not None ) ) ) if inside_tag_block: # Append to the existing tag-based item - if ( - last_item_type - == "open_webui:code_interpreter" - ): - last_item["code"] = ( - last_item.get("code", "") + value - ) - elif last_item_type == "reasoning": - parts = last_item.get("content", []) - if ( - parts - and parts[-1].get("type") - == "output_text" - ): - parts[-1]["text"] += value + if last_item_type == 'open_webui:code_interpreter': + last_item['code'] = last_item.get('code', '') + value + elif last_item_type == 'reasoning': + parts = last_item.get('content', []) + if parts and parts[-1].get('type') == 'output_text': + parts[-1]['text'] += value else: - last_item["content"] = [ + last_item['content'] = [ { - "type": "output_text", - "text": value, + 'type': 'output_text', + 'text': value, } ] else: # solution or other _tag_type message - msg_parts = last_item.get("content", []) - if ( - msg_parts - and msg_parts[-1].get("type") - == "output_text" - ): - msg_parts[-1]["text"] += value + msg_parts = last_item.get('content', []) + if msg_parts and msg_parts[-1].get('type') == 'output_text': + msg_parts[-1]['text'] += value else: - last_item["content"] = [ + last_item['content'] = [ { - "type": "output_text", - "text": value, + 'type': 'output_text', + 'text': value, } ] else: - if ( - not output - or output[-1].get("type") != "message" - ): + if not output or output[-1].get('type') != 'message': output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [ { - "type": "output_text", - "text": "", + 'type': 'output_text', + 'text': '', } ], } ) # Append value to last message item's text - msg_parts = output[-1].get("content", []) - if ( - msg_parts - and msg_parts[-1].get("type") - == "output_text" - ): - msg_parts[-1]["text"] += value + msg_parts = output[-1].get('content', []) + if msg_parts and msg_parts[-1].get('type') == 'output_text': + msg_parts[-1]['text'] += value else: - output[-1]["content"] = [ + output[-1]['content'] = [ { - "type": "output_text", - "text": value, + 'type': 'output_text', + 'text': value, } ] if DETECT_REASONING_TAGS: output, _ = tag_output_handler( - "reasoning", + 'reasoning', reasoning_tags, output, ) output, _ = tag_output_handler( - "solution", + 'solution', DEFAULT_SOLUTION_TAGS, output, ) if DETECT_CODE_INTERPRETER: output, end = tag_output_handler( - "code_interpreter", + 'code_interpreter', DEFAULT_CODE_INTERPRETER_TAGS, output, ) @@ -4128,16 +3809,16 @@ async def streaming_chat_response_handler(response, ctx): if ENABLE_REALTIME_CHAT_SAVE: # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "content": serialize_output(output), - "output": output, + 'content': serialize_output(output), + 'output': output, }, ) else: data = { - "content": serialize_output(output), + 'content': serialize_output(output), } if delta: @@ -4148,51 +3829,48 @@ async def streaming_chat_response_handler(response, ctx): else: await event_emitter( { - "type": "chat:completion", - "data": data, + 'type': 'chat:completion', + 'data': data, } ) except Exception as e: - done = "data: [DONE]" in line + done = 'data: [DONE]' in line if done: pass else: - log.debug(f"Error: {e}") + log.debug(f'Error: {e}') continue await flush_pending_delta_data() if output: # Clean up the last message item - if output[-1].get("type") == "message": - parts = output[-1].get("content", []) - if parts and parts[-1].get("type") == "output_text": - parts[-1]["text"] = parts[-1]["text"].strip() + if output[-1].get('type') == 'message': + parts = output[-1].get('content', []) + if parts and parts[-1].get('type') == 'output_text': + parts[-1]['text'] = parts[-1]['text'].strip() - if not parts[-1]["text"]: + if not parts[-1]['text']: output.pop() if not output: output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [ - {"type": "output_text", "text": ""} - ], + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': ''}], } ) - if output[-1].get("type") == "reasoning": + if output[-1].get('type') == 'reasoning': reasoning_item = output[-1] - if reasoning_item.get("ended_at") is None: - reasoning_item["ended_at"] = time.time() - reasoning_item["duration"] = int( - reasoning_item["ended_at"] - - reasoning_item["started_at"] + if reasoning_item.get('ended_at') is None: + reasoning_item['ended_at'] = time.time() + reasoning_item['duration'] = int( + reasoning_item['ended_at'] - reasoning_item['started_at'] ) - reasoning_item["status"] = "completed" + reasoning_item['status'] = 'completed' if response_tool_calls: tool_calls.append(_split_tool_calls(response_tool_calls)) @@ -4205,69 +3883,61 @@ async def streaming_chat_response_handler(response, ctx): tool_call_retries = 0 tool_call_sources = [] # Track citation sources from tool results all_tool_call_sources = [] # Accumulated sources across all iterations - user_message = get_last_user_message(form_data["messages"]) + user_message = get_last_user_message(form_data['messages']) # Check if citations are enabled for this model - citations_enabled = ( - model.get("info", {}).get("meta", {}).get("capabilities") or {} - ).get("citations", True) + citations_enabled = (model.get('info', {}).get('meta', {}).get('capabilities') or {}).get( + 'citations', True + ) # Use the pre-RAG system content captured before the # initial file-source injection in process_chat_payload. # This ensures restore truly undoes the RAG template. - original_system_content = metadata.get("system_prompt") + original_system_content = metadata.get('system_prompt') if original_system_content is None: - original_system_message = get_system_message(form_data["messages"]) + original_system_message = get_system_message(form_data['messages']) original_system_content = ( - get_content_from_message(original_system_message) - if original_system_message - else None + get_content_from_message(original_system_message) if original_system_message else None ) - while ( - len(tool_calls) > 0 - and tool_call_retries < CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES - ): - + while len(tool_calls) > 0 and tool_call_retries < CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES: tool_call_retries += 1 response_tool_calls = tool_calls.pop(0) # Append function_call items for each tool call for tc in response_tool_calls: - call_id = tc.get("id", "") - func = tc.get("function", {}) + call_id = tc.get('id', '') + func = tc.get('function', {}) output.append( { - "type": "function_call", - "id": call_id or output_id("fc"), - "call_id": call_id, - "name": func.get("name", ""), - "arguments": func.get("arguments", "{}"), - "status": "in_progress", + 'type': 'function_call', + 'id': call_id or output_id('fc'), + 'call_id': call_id, + 'name': func.get('name', ''), + 'arguments': func.get('arguments', '{}'), + 'status': 'in_progress', } ) await event_emitter( { - "type": "chat:completion", - "data": { - "content": serialize_output(output), - "output": output, + 'type': 'chat:completion', + 'data': { + 'content': serialize_output(output), + 'output': output, }, } ) - tools = metadata.get("tools", {}) + tools = metadata.get('tools', {}) results = [] for tool_call in response_tool_calls: - tool_call_id = tool_call.get("id", "") - tool_function_name = tool_call.get("function", {}).get( - "name", "" - ) - tool_args = tool_call.get("function", {}).get("arguments", "{}") + tool_call_id = tool_call.get('id', '') + tool_function_name = tool_call.get('function', {}).get('name', '') + tool_args = tool_call.get('function', {}).get('arguments', '{}') tool_function_params = {} if tool_args and tool_args.strip(): @@ -4280,24 +3950,18 @@ async def streaming_chat_response_handler(response, ctx): try: tool_function_params = json.loads(tool_args) except Exception as e: - log.error( - f"Error parsing tool call arguments: {tool_args}" - ) + log.error(f'Error parsing tool call arguments: {tool_args}') results.append( { - "tool_call_id": tool_call_id, - "content": f"Error: Tool call arguments could not be parsed. The model generated malformed or incomplete JSON for `{tool_function_name}`. Please try again.", + 'tool_call_id': tool_call_id, + 'content': f'Error: Tool call arguments could not be parsed. The model generated malformed or incomplete JSON for `{tool_function_name}`. Please try again.', } ) continue # Ensure arguments are valid JSON for downstream LLM integrations - log.debug( - f"Parsed args from {tool_args} to {tool_function_params}" - ) - tool_call.setdefault("function", {})["arguments"] = json.dumps( - tool_function_params - ) + log.debug(f'Parsed args from {tool_args} to {tool_function_params}') + tool_call.setdefault('function', {})['arguments'] = json.dumps(tool_function_params) tool_result = None tool = None @@ -4306,68 +3970,54 @@ async def streaming_chat_response_handler(response, ctx): if tool_function_name in tools: tool = tools[tool_function_name] - spec = tool.get("spec", {}) + spec = tool.get('spec', {}) - tool_type = tool.get("type", "") - direct_tool = tool.get("direct", False) + tool_type = tool.get('type', '') + direct_tool = tool.get('direct', False) try: - allowed_params = ( - spec.get("parameters", {}) - .get("properties", {}) - .keys() - ) + allowed_params = spec.get('parameters', {}).get('properties', {}).keys() tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in allowed_params + k: v for k, v in tool_function_params.items() if k in allowed_params } if direct_tool: tool_result = await event_caller( { - "type": "execute:tool", - "data": { - "id": str(uuid4()), - "name": tool_function_name, - "params": tool_function_params, - "server": tool.get("server", {}), - "session_id": metadata.get( - "session_id", None - ), + 'type': 'execute:tool', + 'data': { + 'id': str(uuid4()), + 'name': tool_function_name, + 'params': tool_function_params, + 'server': tool.get('server', {}), + 'session_id': metadata.get('session_id', None), }, } ) else: tool_function = get_updated_tool_function( - function=tool["callable"], + function=tool['callable'], extra_params={ - "__messages__": form_data.get( - "messages", [] - ), - "__files__": metadata.get("files", []), + '__messages__': form_data.get('messages', []), + '__files__': metadata.get('files', []), }, ) - tool_result = await tool_function( - **tool_function_params - ) + tool_result = await tool_function(**tool_function_params) except Exception as e: tool_result = str(e) - tool_result, tool_result_files, tool_result_embeds = ( - process_tool_result( - request, - tool_function_name, - tool_result, - tool_type, - direct_tool, - metadata, - user, - ) + tool_result, tool_result_files, tool_result_embeds = process_tool_result( + request, + tool_function_name, + tool_result, + tool_type, + direct_tool, + metadata, + user, ) await terminal_event_handler( @@ -4382,10 +4032,10 @@ async def streaming_chat_response_handler(response, ctx): citations_enabled and tool_function_name in [ - "search_web", - "fetch_url", - "view_knowledge_file", - "query_knowledge_files", + 'search_web', + 'fetch_url', + 'view_knowledge_file', + 'query_knowledge_files', ] and tool_result ): @@ -4394,86 +4044,65 @@ async def streaming_chat_response_handler(response, ctx): tool_name=tool_function_name, tool_params=tool_function_params, tool_result=tool_result, - tool_id=tool.get("tool_id", "") if tool else "", + tool_id=tool.get('tool_id', '') if tool else '', ) tool_call_sources.extend(citation_sources) except Exception as e: - log.exception(f"Error extracting citation source: {e}") + log.exception(f'Error extracting citation source: {e}') results.append( { - "tool_call_id": tool_call_id, - "content": str(tool_result) if tool_result else "", - **( - {"files": tool_result_files} - if tool_result_files - else {} - ), - **( - {"embeds": tool_result_embeds} - if tool_result_embeds - else {} - ), + 'tool_call_id': tool_call_id, + 'content': str(tool_result) if tool_result else '', + **({'files': tool_result_files} if tool_result_files else {}), + **({'embeds': tool_result_embeds} if tool_result_embeds else {}), } ) # Update function_call statuses and append function_call_output items for tc in response_tool_calls: - call_id = tc.get("id", "") + call_id = tc.get('id', '') # Mark function_call as completed for item in output: - if ( - item.get("type") == "function_call" - and item.get("call_id") == call_id - ): - item["status"] = "completed" + if item.get('type') == 'function_call' and item.get('call_id') == call_id: + item['status'] = 'completed' # Update arguments with parsed/sanitized version - item["arguments"] = tc.get("function", {}).get( - "arguments", "{}" - ) + item['arguments'] = tc.get('function', {}).get('arguments', '{}') break for result in results: output.append( { - "type": "function_call_output", - "id": output_id("fco"), - "call_id": result.get("tool_call_id", ""), - "output": [ + 'type': 'function_call_output', + 'id': output_id('fco'), + 'call_id': result.get('tool_call_id', ''), + 'output': [ { - "type": "input_text", - "text": result.get("content", ""), + 'type': 'input_text', + 'text': result.get('content', ''), } ], - "status": "completed", - **( - {"files": result.get("files")} - if result.get("files") - else {} - ), - **( - {"embeds": result.get("embeds")} - if result.get("embeds") - else {} - ), + 'status': 'completed', + **({'files': result.get('files')} if result.get('files') else {}), + **({'embeds': result.get('embeds')} if result.get('embeds') else {}), } ) # Append a new empty message item for the next response output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [{"type": "output_text", "text": ""}], + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': ''}], } ) # Emit citation sources to the frontend for display if citations_enabled: for source in tool_call_sources: - await event_emitter({"type": "source", "data": source}) + await event_emitter({'type': 'source', 'data': source}) # Apply tool source context to messages for the model. # Restoring to pre-RAG original prevents duplicating @@ -4482,23 +4111,21 @@ async def streaming_chat_response_handler(response, ctx): if all_tool_call_sources and user_message: # Restore pre-RAG message state before re-applying # to prevent RAG template duplication. - original_user_message = ( - metadata.get("user_prompt") or user_message - ) + original_user_message = metadata.get('user_prompt') or user_message set_last_user_message_content( original_user_message, - form_data["messages"], + form_data['messages'], ) replace_system_message_content( - original_system_content or "", - form_data["messages"], + original_system_content or '', + form_data['messages'], ) # Build context: file sources with content, # tool sources as citation markers only. source_ids = {} source_context = get_source_context( - metadata.get("sources", []), source_ids + metadata.get('sources', []), source_ids ) + get_source_context( all_tool_call_sources, source_ids, @@ -4512,27 +4139,25 @@ async def streaming_chat_response_handler(response, ctx): user_message, ) if RAG_SYSTEM_CONTEXT: - form_data["messages"] = ( - add_or_update_system_message( - rag_content, - form_data["messages"], - append=True, - ) + form_data['messages'] = add_or_update_system_message( + rag_content, + form_data['messages'], + append=True, ) else: - form_data["messages"] = add_or_update_user_message( + form_data['messages'] = add_or_update_user_message( rag_content, - form_data["messages"], + form_data['messages'], append=False, ) tool_call_sources.clear() await event_emitter( { - "type": "chat:completion", - "data": { - "content": serialize_output(output), - "output": output, + 'type': 'chat:completion', + 'data': { + 'content': serialize_output(output), + 'output': output, }, } ) @@ -4540,10 +4165,10 @@ async def streaming_chat_response_handler(response, ctx): try: new_form_data = { **form_data, - "model": model_id, - "stream": True, - "messages": [ - *form_data["messages"], + 'model': model_id, + 'stream': True, + 'messages': [ + *form_data['messages'], *convert_output_to_messages(output, raw=True), ], } @@ -4567,30 +4192,25 @@ async def streaming_chat_response_handler(response, ctx): MAX_RETRIES = 5 retries = 0 - while ( - output - and output[-1].get("type") == "open_webui:code_interpreter" - and retries < MAX_RETRIES - ): - + while output and output[-1].get('type') == 'open_webui:code_interpreter' and retries < MAX_RETRIES: await event_emitter( { - "type": "chat:completion", - "data": { - "content": serialize_output(output), - "output": output, + 'type': 'chat:completion', + 'data': { + 'content': serialize_output(output), + 'output': output, }, } ) retries += 1 - log.debug(f"Attempt count: {retries}") + log.debug(f'Attempt count: {retries}') ci_item = output[-1] - ci_output = "" + ci_output = '' try: - if ci_item.get("attributes", {}).get("type") == "code": - code = ci_item.get("code", "") + if ci_item.get('attributes', {}).get('type') == 'code': + code = ci_item.get('code', '') # Sanitize code (strips ANSI codes and markdown fences) code = sanitize_code(code) @@ -4612,61 +4232,48 @@ async def streaming_chat_response_handler(response, ctx): builtins.__import__ = restricted_import """) - code = blocking_code + "\n" + code + code = blocking_code + '\n' + code - if ( - request.app.state.config.CODE_INTERPRETER_ENGINE - == "pyodide" - ): + if request.app.state.config.CODE_INTERPRETER_ENGINE == 'pyodide': ci_output = await event_caller( { - "type": "execute:python", - "data": { - "id": str(uuid4()), - "code": code, - "session_id": metadata.get( - "session_id", None - ), - "files": metadata.get("files", []), + 'type': 'execute:python', + 'data': { + 'id': str(uuid4()), + 'code': code, + 'session_id': metadata.get('session_id', None), + 'files': metadata.get('files', []), }, } ) - elif ( - request.app.state.config.CODE_INTERPRETER_ENGINE - == "jupyter" - ): + elif request.app.state.config.CODE_INTERPRETER_ENGINE == 'jupyter': ci_output = await execute_code_jupyter( request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, code, ( request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN - if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH - == "token" + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH == 'token' else None ), ( request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD - if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH - == "password" + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH == 'password' else None ), request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, ) else: - ci_output = { - "stdout": "Code interpreter engine not configured." - } + ci_output = {'stdout': 'Code interpreter engine not configured.'} - log.debug(f"Code interpreter output: {ci_output}") + log.debug(f'Code interpreter output: {ci_output}') if isinstance(ci_output, dict): - stdout = ci_output.get("stdout", "") + stdout = ci_output.get('stdout', '') if isinstance(stdout, str): - stdoutLines = stdout.split("\n") + stdoutLines = stdout.split('\n') for idx, line in enumerate(stdoutLines): - - if "data:image/png;base64" in line: + if 'data:image/png;base64' in line: image_url = get_image_url_from_base64( request, line, @@ -4674,50 +4281,46 @@ async def streaming_chat_response_handler(response, ctx): user, ) if image_url: - stdoutLines[idx] = ( - f"![Output Image]({image_url})" - ) + stdoutLines[idx] = f'![Output Image]({image_url})' - ci_output["stdout"] = "\n".join(stdoutLines) + ci_output['stdout'] = '\n'.join(stdoutLines) - result = ci_output.get("result", "") + result = ci_output.get('result', '') if isinstance(result, str): - resultLines = result.split("\n") + resultLines = result.split('\n') for idx, line in enumerate(resultLines): - if "data:image/png;base64" in line: + if 'data:image/png;base64' in line: image_url = get_image_url_from_base64( request, line, metadata, user, ) - resultLines[idx] = ( - f"![Output Image]({image_url})" - ) - ci_output["result"] = "\n".join(resultLines) + resultLines[idx] = f'![Output Image]({image_url})' + ci_output['result'] = '\n'.join(resultLines) except Exception as e: ci_output = str(e) - ci_item["output"] = ci_output - ci_item["status"] = "completed" + ci_item['output'] = ci_output + ci_item['status'] = 'completed' output.append( { - "type": "message", - "id": output_id("msg"), - "status": "in_progress", - "role": "assistant", - "content": [{"type": "output_text", "text": ""}], + 'type': 'message', + 'id': output_id('msg'), + 'status': 'in_progress', + 'role': 'assistant', + 'content': [{'type': 'output_text', 'text': ''}], } ) await event_emitter( { - "type": "chat:completion", - "data": { - "content": serialize_output(output), - "output": output, + 'type': 'chat:completion', + 'data': { + 'content': serialize_output(output), + 'output': output, }, } ) @@ -4725,10 +4328,10 @@ async def streaming_chat_response_handler(response, ctx): try: new_form_data = { **form_data, - "model": model_id, - "stream": True, - "messages": [ - *form_data["messages"], + 'model': model_id, + 'stream': True, + 'messages': [ + *form_data['messages'], *convert_output_to_messages(output, raw=True), ], } @@ -4750,33 +4353,33 @@ async def streaming_chat_response_handler(response, ctx): # Mark all in-progress items as completed for item in output: - if item.get("status") == "in_progress": - item["status"] = "completed" + if item.get('status') == 'in_progress': + item['status'] = 'completed' - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + title = Chats.get_chat_title_by_id(metadata['chat_id']) data = { - "done": True, - "content": serialize_output(output), - "output": output, - "title": title, + 'done': True, + 'content': serialize_output(output), + 'output': output, + 'title': title, } if not ENABLE_REALTIME_CHAT_SAVE: # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "content": serialize_output(output), - "output": output, - **({"usage": usage} if usage else {}), + 'content': serialize_output(output), + 'output': output, + **({'usage': usage} if usage else {}), }, ) elif usage: Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - {"usage": usage}, + metadata['chat_id'], + metadata['message_id'], + {'usage': usage}, ) # Send a webhook notification if the user is not active @@ -4786,35 +4389,35 @@ async def streaming_chat_response_handler(response, ctx): await post_webhook( request.app.state.WEBUI_NAME, webhook_url, - f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", + f'{title} - {request.app.state.config.WEBUI_URL}/c/{metadata["chat_id"]}\n\n{content}', { - "action": "chat", - "message": content, - "title": title, - "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", + 'action': 'chat', + 'message': content, + 'title': title, + 'url': f'{request.app.state.config.WEBUI_URL}/c/{metadata["chat_id"]}', }, ) await event_emitter( { - "type": "chat:completion", - "data": data, + 'type': 'chat:completion', + 'data': data, } ) await background_tasks_handler(ctx) except asyncio.CancelledError: - log.warning("Task was cancelled!") - await event_emitter({"type": "chat:tasks:cancel"}) + log.warning('Task was cancelled!') + await event_emitter({'type': 'chat:tasks:cancel'}) if not ENABLE_REALTIME_CHAT_SAVE: # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + metadata['chat_id'], + metadata['message_id'], { - "content": serialize_output(output), - "output": output, + 'content': serialize_output(output), + 'output': output, }, ) @@ -4827,13 +4430,13 @@ async def streaming_chat_response_handler(response, ctx): # Fallback to the original response async def stream_wrapper(original_generator, events): def wrap_item(item): - return f"data: {item}\n\n" + return f'data: {item}\n\n' for event in events: event, _ = await process_filter_functions( request=request, filter_functions=filter_functions, - filter_type="stream", + filter_type='stream', form_data=event, extra_params=extra_params, ) @@ -4845,7 +4448,7 @@ async def streaming_chat_response_handler(response, ctx): data, _ = await process_filter_functions( request=request, filter_functions=filter_functions, - filter_type="stream", + filter_type='stream', form_data=data, extra_params=extra_params, ) @@ -4867,8 +4470,8 @@ async def process_chat_response(response, ctx): # Non standard response if not any( - content_type in response.headers["Content-Type"] - for content_type in ["text/event-stream", "application/x-ndjson"] + content_type in response.headers['Content-Type'] + for content_type in ['text/event-stream', 'application/x-ndjson'] ): return response diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 809a982e6b..dec97abd25 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -33,7 +33,7 @@ def get_allow_block_lists(filter_list): if filter_list: for d in filter_list: - if d.startswith("!"): + if d.startswith('!'): # Domains starting with "!" → blocked block_list.append(d[1:].strip()) else: @@ -43,9 +43,7 @@ def get_allow_block_lists(filter_list): return allow_list, block_list -def is_string_allowed( - string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None -) -> bool: +def is_string_allowed(string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None) -> bool: """ Checks if a string is allowed based on the provided filter list. :param string: The string or sequence of strings to check (e.g., domain or hostname). @@ -94,7 +92,7 @@ def get_message_list(messages_map, message_id): visited_message_ids = set() while current_message: - message_id = current_message.get("id") + message_id = current_message.get('id') if message_id in visited_message_ids: # Cycle detected, break to prevent infinite loop break @@ -103,7 +101,7 @@ def get_message_list(messages_map, message_id): visited_message_ids.add(message_id) message_list.append(current_message) - parent_id = current_message.get("parentId") # Use .get() for safety + parent_id = current_message.get('parentId') # Use .get() for safety current_message = messages_map.get(parent_id) if parent_id else None message_list.reverse() @@ -111,28 +109,23 @@ def get_message_list(messages_map, message_id): def get_messages_content(messages: list[dict]) -> str: - return "\n".join( - [ - f"{message['role'].upper()}: {get_content_from_message(message)}" - for message in messages - ] - ) + return '\n'.join([f'{message["role"].upper()}: {get_content_from_message(message)}' for message in messages]) def get_last_user_message_item(messages: list[dict]) -> Optional[dict]: for message in reversed(messages): - if message["role"] == "user": + if message['role'] == 'user': return message return None def get_content_from_message(message: dict) -> Optional[str]: - if isinstance(message.get("content"), list): - for item in message["content"]: - if item["type"] == "text": - return item["text"] + if isinstance(message.get('content'), list): + for item in message['content']: + if item['type'] == 'text': + return item['text'] else: - return message.get("content") + return message.get('content') return None @@ -161,111 +154,101 @@ def convert_output_to_messages(output: list, raw: bool = False) -> list[dict]: if pending_content or pending_tool_calls: messages.append( { - "role": "assistant", - "content": "\n".join(pending_content) if pending_content else "", - **( - {"tool_calls": pending_tool_calls} if pending_tool_calls else {} - ), + 'role': 'assistant', + 'content': '\n'.join(pending_content) if pending_content else '', + **({'tool_calls': pending_tool_calls} if pending_tool_calls else {}), } ) pending_content = [] pending_tool_calls = [] for item in output: - item_type = item.get("type", "") + item_type = item.get('type', '') - if item_type == "message": + if item_type == 'message': # Extract text from output_text content parts - content_parts = item.get("content", []) - text = "" + content_parts = item.get('content', []) + text = '' for part in content_parts: - if part.get("type") == "output_text": - text += part.get("text", "") + if part.get('type') == 'output_text': + text += part.get('text', '') if text: pending_content.append(text) - elif item_type == "function_call": + elif item_type == 'function_call': # Collect tool calls to batch into assistant message - arguments = item.get("arguments", "{}") + arguments = item.get('arguments', '{}') # Ensure arguments is always a JSON string if not isinstance(arguments, str): arguments = json.dumps(arguments) pending_tool_calls.append( { - "id": item.get("call_id", ""), - "type": "function", - "function": { - "name": item.get("name", ""), - "arguments": arguments, + 'id': item.get('call_id', ''), + 'type': 'function', + 'function': { + 'name': item.get('name', ''), + 'arguments': arguments, }, } ) - elif item_type == "function_call_output": + elif item_type == 'function_call_output': # Flush any pending content/tool_calls before adding tool result flush_pending() # Extract text from output content parts - output_parts = item.get("output", []) - content = "" + output_parts = item.get('output', []) + content = '' for part in output_parts: - if part.get("type") == "input_text": - output_text = part.get("text", "") - content += ( - str(output_text) - if not isinstance(output_text, str) - else output_text - ) + if part.get('type') == 'input_text': + output_text = part.get('text', '') + content += str(output_text) if not isinstance(output_text, str) else output_text messages.append( { - "role": "tool", - "tool_call_id": item.get("call_id", ""), - "content": content, + 'role': 'tool', + 'tool_call_id': item.get('call_id', ''), + 'content': content, } ) - elif item_type == "reasoning": + elif item_type == 'reasoning': if raw: # Include reasoning with original tags for LLM re-processing - reasoning_text = "" - source_list = item.get("summary", []) or item.get("content", []) + reasoning_text = '' + source_list = item.get('summary', []) or item.get('content', []) for part in source_list: - if part.get("type") == "output_text": - reasoning_text += part.get("text", "") - elif "text" in part: - reasoning_text += part.get("text", "") + if part.get('type') == 'output_text': + reasoning_text += part.get('text', '') + elif 'text' in part: + reasoning_text += part.get('text', '') if reasoning_text: - start_tag = item.get("start_tag", "") - end_tag = item.get("end_tag", "") - pending_content.append(f"{start_tag}{reasoning_text}{end_tag}") + start_tag = item.get('start_tag', '') + end_tag = item.get('end_tag', '') + pending_content.append(f'{start_tag}{reasoning_text}{end_tag}') # else: skip reasoning blocks for normal LLM messages - elif item_type == "open_webui:code_interpreter": + elif item_type == 'open_webui:code_interpreter': # Always include code interpreter content so the LLM knows # the code was already executed and doesn't retry. - code = item.get("code", "") - code_output = item.get("output", "") + code = item.get('code', '') + code_output = item.get('output', '') if code: - pending_content.append( - f"\n{code}\n" - ) + pending_content.append(f'\n{code}\n') if code_output: if isinstance(code_output, dict): - stdout = code_output.get("stdout", "") - result = code_output.get("result", "") + stdout = code_output.get('stdout', '') + result = code_output.get('result', '') output_text = stdout or result else: output_text = str(code_output) if output_text: - pending_content.append( - f"\n{output_text}\n" - ) + pending_content.append(f'\n{output_text}\n') - elif item_type.startswith("open_webui:"): + elif item_type.startswith('open_webui:'): # Skip other extension types pass @@ -288,41 +271,41 @@ def set_last_user_message_content(content: str, messages: list[dict]) -> list[di Handles both plain-string and list-of-parts content formats. """ for message in reversed(messages): - if message.get("role") == "user": - if isinstance(message.get("content"), list): - for item in message["content"]: - if item.get("type") == "text": - item["text"] = content + if message.get('role') == 'user': + if isinstance(message.get('content'), list): + for item in message['content']: + if item.get('type') == 'text': + item['text'] = content break else: - message["content"] = content + message['content'] = content break return messages def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]: for message in reversed(messages): - if message["role"] == "assistant": + if message['role'] == 'assistant': return message return None def get_last_assistant_message(messages: list[dict]) -> Optional[str]: for message in reversed(messages): - if message["role"] == "assistant": + if message['role'] == 'assistant': return get_content_from_message(message) return None def get_system_message(messages: list[dict]) -> Optional[dict]: for message in messages: - if message["role"] == "system": + if message['role'] == 'system': return message return None def remove_system_message(messages: list[dict]) -> list[dict]: - return [message for message in messages if message["role"] != "system"] + return [message for message in messages if message['role'] != 'system'] def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]: @@ -330,32 +313,30 @@ def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict] def update_message_content(message: dict, content: str, append: bool = True) -> dict: - if isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": + if isinstance(message['content'], list): + for item in message['content']: + if item['type'] == 'text': if append: - item["text"] = f"{item['text']}\n{content}" + item['text'] = f'{item["text"]}\n{content}' else: - item["text"] = f"{content}\n{item['text']}" + item['text'] = f'{content}\n{item["text"]}' else: if append: - message["content"] = f"{message['content']}\n{content}" + message['content'] = f'{message["content"]}\n{content}' else: - message["content"] = f"{content}\n{message['content']}" + message['content'] = f'{content}\n{message["content"]}' return message def replace_system_message_content(content: str, messages: list[dict]) -> dict: for message in messages: - if message["role"] == "system": - message["content"] = content + if message['role'] == 'system': + message['content'] = content break return messages -def add_or_update_system_message( - content: str, messages: list[dict], append: bool = False -): +def add_or_update_system_message(content: str, messages: list[dict], append: bool = False): """ Adds a new system message at the beginning of the messages list or updates the existing system message at the beginning. @@ -365,11 +346,11 @@ def add_or_update_system_message( :return: The updated list of message dictionaries. """ - if messages and messages[0].get("role") == "system": + if messages and messages[0].get('role') == 'system': messages[0] = update_message_content(messages[0], content, append) else: # Insert at the beginning - messages.insert(0, {"role": "system", "content": content}) + messages.insert(0, {'role': 'system', 'content': content}) return messages @@ -384,20 +365,18 @@ def add_or_update_user_message(content: str, messages: list[dict], append: bool :return: The updated list of message dictionaries. """ - if messages and messages[-1].get("role") == "user": + if messages and messages[-1].get('role') == 'user': messages[-1] = update_message_content(messages[-1], content, append) else: # Insert at the end - messages.append({"role": "user", "content": content}) + messages.append({'role': 'user', 'content': content}) return messages -def prepend_to_first_user_message_content( - content: str, messages: list[dict] -) -> list[dict]: +def prepend_to_first_user_message_content(content: str, messages: list[dict]) -> list[dict]: for message in messages: - if message["role"] == "user": + if message['role'] == 'user': message = update_message_content(message, content, append=False) break return messages @@ -413,21 +392,21 @@ def append_or_update_assistant_message(content: str, messages: list[dict]): :return: The updated list of message dictionaries. """ - if messages and messages[-1].get("role") == "assistant": - messages[-1]["content"] = f"{messages[-1]['content']}\n{content}" + if messages and messages[-1].get('role') == 'assistant': + messages[-1]['content'] = f'{messages[-1]["content"]}\n{content}' else: # Insert at the end - messages.append({"role": "assistant", "content": content}) + messages.append({'role': 'assistant', 'content': content}) return messages def openai_chat_message_template(model: str): return { - "id": f"{model}-{str(uuid.uuid4())}", - "created": int(time.time()), - "model": model, - "choices": [{"index": 0, "logprobs": None, "finish_reason": None}], + 'id': f'{model}-{str(uuid.uuid4())}', + 'created': int(time.time()), + 'model': model, + 'choices': [{'index': 0, 'logprobs': None, 'finish_reason': None}], } @@ -439,25 +418,25 @@ def openai_chat_chunk_message_template( usage: Optional[dict] = None, ) -> dict: template = openai_chat_message_template(model) - template["object"] = "chat.completion.chunk" + template['object'] = 'chat.completion.chunk' - template["choices"][0]["index"] = 0 - template["choices"][0]["delta"] = {} + template['choices'][0]['index'] = 0 + template['choices'][0]['delta'] = {} if content: - template["choices"][0]["delta"]["content"] = content + template['choices'][0]['delta']['content'] = content if reasoning_content: - template["choices"][0]["delta"]["reasoning_content"] = reasoning_content + template['choices'][0]['delta']['reasoning_content'] = reasoning_content if tool_calls: - template["choices"][0]["delta"]["tool_calls"] = tool_calls + template['choices'][0]['delta']['tool_calls'] = tool_calls if not content and not reasoning_content and not tool_calls: - template["choices"][0]["finish_reason"] = "stop" + template['choices'][0]['finish_reason'] = 'stop' if usage: - template["usage"] = usage + template['usage'] = usage return template @@ -469,19 +448,19 @@ def openai_chat_completion_message_template( usage: Optional[dict] = None, ) -> dict: template = openai_chat_message_template(model) - template["object"] = "chat.completion" + template['object'] = 'chat.completion' if message is not None: - template["choices"][0]["message"] = { - "role": "assistant", - "content": message, - **({"reasoning_content": reasoning_content} if reasoning_content else {}), - **({"tool_calls": tool_calls} if tool_calls else {}), + template['choices'][0]['message'] = { + 'role': 'assistant', + 'content': message, + **({'reasoning_content': reasoning_content} if reasoning_content else {}), + **({'tool_calls': tool_calls} if tool_calls else {}), } - template["choices"][0]["finish_reason"] = "tool_calls" if tool_calls else "stop" + template['choices'][0]['finish_reason'] = 'tool_calls' if tool_calls else 'stop' if usage: - template["usage"] = usage + template['usage'] = usage return template @@ -496,13 +475,13 @@ def get_gravatar_url(email): hash_hex = hash_object.hexdigest() # Grab the actual image URL - return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp" + return f'https://www.gravatar.com/avatar/{hash_hex}?d=mp' def calculate_sha256(file_path, chunk_size): # Compute SHA-256 hash of a file efficiently in chunks sha256 = hashlib.sha256() - with open(file_path, "rb") as f: + with open(file_path, 'rb') as f: while chunk := f.read(chunk_size): sha256.update(chunk) return sha256.hexdigest() @@ -512,17 +491,17 @@ def calculate_sha256_string(string): # Create a new SHA-256 hash object sha256_hash = hashlib.sha256() # Update the hash object with the bytes of the input string - sha256_hash.update(string.encode("utf-8")) + sha256_hash.update(string.encode('utf-8')) # Get the hexadecimal representation of the hash hashed_string = sha256_hash.hexdigest() return hashed_string def validate_email_format(email: str) -> bool: - if email.endswith("@localhost"): + if email.endswith('@localhost'): return True - return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email)) + return bool(re.match(r'[^@]+@[^@]+\.[^@]+', email)) def sanitize_filename(file_name): @@ -530,10 +509,10 @@ def sanitize_filename(file_name): lower_case_file_name = file_name.lower() # Remove special characters using regular expression - sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name) + sanitized_file_name = re.sub(r'[^\w\s]', '', lower_case_file_name) # Replace spaces with dashes - final_file_name = re.sub(r"\s+", "-", sanitized_file_name) + final_file_name = re.sub(r'\s+', '-', sanitized_file_name) return final_file_name @@ -543,13 +522,11 @@ def sanitize_text_for_db(text: str) -> str: if not isinstance(text, str): return text # Remove null bytes - text = text.replace("\x00", "").replace("\u0000", "") + text = text.replace('\x00', '').replace('\u0000', '') # Remove invalid UTF-8 surrogate characters that can cause encoding errors # This handles cases where binary data or encoding issues introduced surrogates try: - text = text.encode("utf-8", errors="surrogatepass").decode( - "utf-8", errors="ignore" - ) + text = text.encode('utf-8', errors='surrogatepass').decode('utf-8', errors='ignore') except (UnicodeEncodeError, UnicodeDecodeError): pass return text @@ -582,15 +559,9 @@ def sanitize_metadata(metadata: dict) -> dict: if isinstance(obj, (str, int, float, bool, type(None))): return obj if isinstance(obj, dict): - return { - k: _sanitize(v) - for k, v in obj.items() - if not callable(v) and _is_serializable(v) - } + return {k: _sanitize(v) for k, v in obj.items() if not callable(v) and _is_serializable(v)} if isinstance(obj, list): - return [ - _sanitize(v) for v in obj if not callable(v) and _is_serializable(v) - ] + return [_sanitize(v) for v in obj if not callable(v) and _is_serializable(v)] if callable(obj): return None # Last resort: try to see if it's serializable @@ -622,8 +593,8 @@ def extract_folders_after_data_docs(path): # Find the index of '/data/docs' in the path try: - index_data_docs = parts.index("data") + 1 - index_docs = parts.index("docs", index_data_docs) + 1 + index_data_docs = parts.index('data') + 1 + index_docs = parts.index('docs', index_data_docs) + 1 except ValueError: return [] @@ -632,37 +603,37 @@ def extract_folders_after_data_docs(path): folders = parts[index_docs:-1] for idx, _ in enumerate(folders): - tags.append("/".join(folders[: idx + 1])) + tags.append('/'.join(folders[: idx + 1])) return tags def parse_duration(duration: str) -> Optional[timedelta]: - if duration == "-1" or duration == "0": + if duration == '-1' or duration == '0': return None # Regular expression to find number and unit pairs - pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)" + pattern = r'(-?\d+(\.\d+)?)(ms|s|m|h|d|w)' matches = re.findall(pattern, duration) if not matches: - raise ValueError("Invalid duration string") + raise ValueError('Invalid duration string') total_duration = timedelta() for number, _, unit in matches: number = float(number) - if unit == "ms": + if unit == 'ms': total_duration += timedelta(milliseconds=number) - elif unit == "s": + elif unit == 's': total_duration += timedelta(seconds=number) - elif unit == "m": + elif unit == 'm': total_duration += timedelta(minutes=number) - elif unit == "h": + elif unit == 'h': total_duration += timedelta(hours=number) - elif unit == "d": + elif unit == 'd': total_duration += timedelta(days=number) - elif unit == "w": + elif unit == 'w': total_duration += timedelta(weeks=number) return total_duration @@ -670,52 +641,48 @@ def parse_duration(duration: str) -> Optional[timedelta]: def parse_ollama_modelfile(model_text): parameters_meta = { - "mirostat": int, - "mirostat_eta": float, - "mirostat_tau": float, - "num_ctx": int, - "repeat_last_n": int, - "repeat_penalty": float, - "temperature": float, - "seed": int, - "tfs_z": float, - "num_predict": int, - "top_k": int, - "top_p": float, - "num_keep": int, - "presence_penalty": float, - "frequency_penalty": float, - "num_batch": int, - "num_gpu": int, - "use_mmap": bool, - "use_mlock": bool, - "num_thread": int, + 'mirostat': int, + 'mirostat_eta': float, + 'mirostat_tau': float, + 'num_ctx': int, + 'repeat_last_n': int, + 'repeat_penalty': float, + 'temperature': float, + 'seed': int, + 'tfs_z': float, + 'num_predict': int, + 'top_k': int, + 'top_p': float, + 'num_keep': int, + 'presence_penalty': float, + 'frequency_penalty': float, + 'num_batch': int, + 'num_gpu': int, + 'use_mmap': bool, + 'use_mlock': bool, + 'num_thread': int, } - data = {"base_model_id": None, "params": {}} + data = {'base_model_id': None, 'params': {}} # Parse base model - base_model_match = re.search( - r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE - ) + base_model_match = re.search(r'^FROM\s+(\w+)', model_text, re.MULTILINE | re.IGNORECASE) if base_model_match: - data["base_model_id"] = base_model_match.group(1) + data['base_model_id'] = base_model_match.group(1) # Parse template - template_match = re.search( - r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE - ) + template_match = re.search(r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE) if template_match: - data["params"] = {"template": template_match.group(1).strip()} + data['params'] = {'template': template_match.group(1).strip()} # Parse stops stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE) if stops: - data["params"]["stop"] = stops + data['params']['stop'] = stops # Parse other parameters from the provided list for param, param_type in parameters_meta.items(): - param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE) + param_match = re.search(rf'PARAMETER {param} (.+)', model_text, re.IGNORECASE) if param_match: value = param_match.group(1) @@ -725,39 +692,35 @@ def parse_ollama_modelfile(model_text): elif param_type is float: value = float(value) elif param_type is bool: - value = value.lower() == "true" + value = value.lower() == 'true' except Exception as e: - log.exception(f"Failed to parse parameter {param}: {e}") + log.exception(f'Failed to parse parameter {param}: {e}') continue - data["params"][param] = value + data['params'][param] = value # Parse adapter - adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE) + adapter_match = re.search(r'ADAPTER (.+)', model_text, re.IGNORECASE) if adapter_match: - data["params"]["adapter"] = adapter_match.group(1) + data['params']['adapter'] = adapter_match.group(1) # Parse system description - system_desc_match = re.search( - r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE - ) - system_desc_match_single = re.search( - r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE - ) + system_desc_match = re.search(r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE) + system_desc_match_single = re.search(r'SYSTEM\s+([^\n]+)', model_text, re.IGNORECASE) if system_desc_match: - data["params"]["system"] = system_desc_match.group(1).strip() + data['params']['system'] = system_desc_match.group(1).strip() elif system_desc_match_single: - data["params"]["system"] = system_desc_match_single.group(1).strip() + data['params']['system'] = system_desc_match_single.group(1).strip() # Parse messages messages = [] - message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE) + message_matches = re.findall(r'MESSAGE (\w+) (.+)', model_text, re.IGNORECASE) for role, content in message_matches: - messages.append({"role": role, "content": content}) + messages.append({'role': role, 'content': content}) if messages: - data["params"]["messages"] = messages + data['params']['messages'] = messages return data @@ -769,10 +732,10 @@ def convert_logit_bias_input_to_json(logit_bias_input) -> Optional[str]: if isinstance(logit_bias_input, dict): return json.dumps(logit_bias_input) - logit_bias_pairs = logit_bias_input.split(",") + logit_bias_pairs = logit_bias_input.split(',') logit_bias_json = {} for pair in logit_bias_pairs: - token, bias = pair.split(":") + token, bias = pair.split(':') token = str(token.strip()) bias = int(bias.strip()) bias = 100 if bias > 100 else -100 if bias < -100 else bias @@ -834,13 +797,13 @@ def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[ try: if isinstance(supported, str): - supported = supported.split(",") + supported = supported.split(',') - supported = [s for s in supported if s.strip() and "/" in s] + supported = [s for s in supported if s.strip() and '/' in s] if len(supported) == 0: # Default to common types if none are specified - supported = ["audio/*", "video/webm"] + supported = ['audio/*', 'video/webm'] match = mimeparse.best_match(supported, header) if not match: @@ -854,15 +817,13 @@ def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[ return match except Exception as e: - log.exception(f"Failed to match mime type {header}: {e}") + log.exception(f'Failed to match mime type {header}: {e}') return None def extract_urls(text: str) -> list[str]: # Regex pattern to match URLs - url_pattern = re.compile( - r"(https?://[^\s]+)", re.IGNORECASE - ) # Matches http and https URLs + url_pattern = re.compile(r'(https?://[^\s]+)', re.IGNORECASE) # Matches http and https URLs return url_pattern.findall(text) @@ -882,9 +843,7 @@ async def stream_wrapper(response, session, content_handler=None): This is more reliable than BackgroundTask which may not run if client disconnects. """ try: - stream = ( - content_handler(response.content) if content_handler else response.content - ) + stream = content_handler(response.content) if content_handler else response.content async for chunk in stream: yield chunk finally: @@ -906,7 +865,7 @@ def stream_chunks_handler(stream: aiohttp.StreamReader): return stream async def yield_safe_stream_chunks(): - buffer = b"" + buffer = b'' skip_mode = False async for data, _ in stream.iter_chunks(): @@ -915,9 +874,9 @@ def stream_chunks_handler(stream: aiohttp.StreamReader): # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line) if skip_mode and len(buffer) > max_buffer_size: - buffer = b"" + buffer = b'' - lines = (buffer + data).split(b"\n") + lines = (buffer + data).split(b'\n') # Process complete lines (except the last possibly incomplete fragment) for i in range(len(lines) - 1): @@ -929,18 +888,18 @@ def stream_chunks_handler(stream: aiohttp.StreamReader): skip_mode = False yield line else: - yield b"data: {}" - yield b"\n" + yield b'data: {}' + yield b'\n' else: # Normal mode: check if line exceeds limit if len(line) > max_buffer_size: skip_mode = True - yield b"data: {}" - yield b"\n" - log.info(f"Skip mode triggered, line size: {len(line)}") + yield b'data: {}' + yield b'\n' + log.info(f'Skip mode triggered, line size: {len(line)}') else: yield line - yield b"\n" + yield b'\n' # Save the last incomplete fragment buffer = lines[-1] @@ -948,13 +907,13 @@ def stream_chunks_handler(stream: aiohttp.StreamReader): # Check if buffer exceeds limit if not skip_mode and len(buffer) > max_buffer_size: skip_mode = True - log.info(f"Skip mode triggered, buffer size: {len(buffer)}") + log.info(f'Skip mode triggered, buffer size: {len(buffer)}') # Clear oversized buffer to prevent unlimited growth - buffer = b"" + buffer = b'' # Process remaining buffer data if buffer and not skip_mode: yield buffer - yield b"\n" + yield b'\n' return yield_safe_stream_chunks() diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 108cd0b4b0..e579a3e3e7 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -41,22 +41,22 @@ async def fetch_ollama_models(request: Request, user: UserModel = None): raw_ollama_models = await ollama.get_all_models(request, user=user) return [ { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": model, - "connection_type": model.get("connection_type", "local"), - "tags": model.get("tags", []), + 'id': model['model'], + 'name': model['name'], + 'object': 'model', + 'created': int(time.time()), + 'owned_by': 'ollama', + 'ollama': model, + 'connection_type': model.get('connection_type', 'local'), + 'tags': model.get('tags', []), } - for model in raw_ollama_models["models"] + for model in raw_ollama_models['models'] ] async def fetch_openai_models(request: Request, user: UserModel = None): openai_response = await openai.get_all_models(request, user=user) - return openai_response["data"] + return openai_response['data'] async def get_all_base_models(request: Request, user: UserModel = None): @@ -72,9 +72,7 @@ async def get_all_base_models(request: Request, user: UserModel = None): ) function_task = get_function_models(request) - openai_models, ollama_models, function_models = await asyncio.gather( - openai_task, ollama_task, function_task - ) + openai_models, ollama_models, function_models = await asyncio.gather(openai_task, ollama_task, function_task) return function_models + openai_models + ollama_models @@ -103,15 +101,15 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: arena_models = [ { - "id": model["id"], - "name": model["name"], - "info": { - "meta": model["meta"], + 'id': model['id'], + 'name': model['name'], + 'info': { + 'meta': model['meta'], }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, + 'object': 'model', + 'created': int(time.time()), + 'owned_by': 'arena', + 'arena': True, } for model in request.app.state.config.EVALUATION_ARENA_MODELS ] @@ -119,45 +117,35 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) # Add default arena model arena_models = [ { - "id": DEFAULT_ARENA_MODEL["id"], - "name": DEFAULT_ARENA_MODEL["name"], - "info": { - "meta": DEFAULT_ARENA_MODEL["meta"], + 'id': DEFAULT_ARENA_MODEL['id'], + 'name': DEFAULT_ARENA_MODEL['name'], + 'info': { + 'meta': DEFAULT_ARENA_MODEL['meta'], }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, + 'object': 'model', + 'created': int(time.time()), + 'owned_by': 'arena', + 'arena': True, } ] models = models + arena_models - global_action_ids = [ - function.id for function in Functions.get_global_action_functions() - ] - enabled_action_ids = [ - function.id - for function in Functions.get_functions_by_type("action", active_only=True) - ] + global_action_ids = [function.id for function in Functions.get_global_action_functions()] + enabled_action_ids = [function.id for function in Functions.get_functions_by_type('action', active_only=True)] - global_filter_ids = [ - function.id for function in Functions.get_global_filter_functions() - ] - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] + global_filter_ids = [function.id for function in Functions.get_global_filter_functions()] + enabled_filter_ids = [function.id for function in Functions.get_functions_by_type('filter', active_only=True)] custom_models = Models.get_all_models() # Single O(1) lookup: Ollama base names first, then exact IDs (exact wins). base_model_lookup = {} for model in models: - if model.get("owned_by") == "ollama": - base_model_lookup.setdefault(model["id"].split(":")[0], model) - base_model_lookup[model["id"]] = model + if model.get('owned_by') == 'ollama': + base_model_lookup.setdefault(model['id'].split(':')[0], model) + base_model_lookup[model['id']] = model - existing_ids = {m["id"] for m in models} + existing_ids = {m['id'] for m in models} for custom_model in custom_models: if custom_model.base_model_id is None: @@ -166,26 +154,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) if model: if custom_model.is_active: - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() + model['name'] = custom_model.name + model['info'] = custom_model.model_dump() action_ids = [] filter_ids = [] - if "info" in model: - if "meta" in model["info"]: - action_ids.extend( - model["info"]["meta"].get("actionIds", []) - ) - filter_ids.extend( - model["info"]["meta"].get("filterIds", []) - ) + if 'info' in model: + if 'meta' in model['info']: + action_ids.extend(model['info']['meta'].get('actionIds', [])) + filter_ids.extend(model['info']['meta'].get('filterIds', [])) - if "params" in model["info"]: - del model["info"]["params"] + if 'params' in model['info']: + del model['info']['params'] - model["action_ids"] = action_ids - model["filter_ids"] = filter_ids + model['action_ids'] = action_ids + model['filter_ids'] = filter_ids else: models.remove(model) @@ -193,38 +177,36 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) if custom_model.id in existing_ids: continue - owned_by = "openai" + owned_by = 'openai' connection_type = None pipe = None base_model = base_model_lookup.get(custom_model.base_model_id) if base_model is None: - base_model = base_model_lookup.get( - custom_model.base_model_id.split(":")[0] - ) + base_model = base_model_lookup.get(custom_model.base_model_id.split(':')[0]) if base_model: - owned_by = base_model.get("owned_by", "unknown") - if "pipe" in base_model: - pipe = base_model["pipe"] - connection_type = base_model.get("connection_type", None) + owned_by = base_model.get('owned_by', 'unknown') + if 'pipe' in base_model: + pipe = base_model['pipe'] + connection_type = base_model.get('connection_type', None) model = { - "id": f"{custom_model.id}", - "name": custom_model.name, - "object": "model", - "created": custom_model.created_at, - "owned_by": owned_by, - "connection_type": connection_type, - "preset": True, - **({"pipe": pipe} if pipe is not None else {}), + 'id': f'{custom_model.id}', + 'name': custom_model.name, + 'object': 'model', + 'created': custom_model.created_at, + 'owned_by': owned_by, + 'connection_type': connection_type, + 'preset': True, + **({'pipe': pipe} if pipe is not None else {}), } info = custom_model.model_dump() - if "params" in info: + if 'params' in info: # Remove params to avoid exposing sensitive info - del info["params"] + del info['params'] - model["info"] = info + model['info'] = info action_ids = [] filter_ids = [] @@ -232,32 +214,32 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) if custom_model.meta: meta = custom_model.meta.model_dump() - if "actionIds" in meta: - action_ids.extend(meta["actionIds"]) + if 'actionIds' in meta: + action_ids.extend(meta['actionIds']) - if "filterIds" in meta: - filter_ids.extend(meta["filterIds"]) + if 'filterIds' in meta: + filter_ids.extend(meta['filterIds']) - model["action_ids"] = action_ids - model["filter_ids"] = filter_ids + model['action_ids'] = action_ids + model['filter_ids'] = filter_ids models.append(model) # Process action_ids to get the actions def get_action_items_from_module(function, module): actions = [] - if hasattr(module, "actions"): + if hasattr(module, 'actions'): actions = module.actions return [ { - "id": f"{function.id}.{action['id']}", - "name": action.get("name", f"{function.name} ({action['id']})"), - "description": function.meta.description, - "icon": action.get( - "icon_url", - function.meta.manifest.get("icon_url", None) - or getattr(module, "icon_url", None) - or getattr(module, "icon", None), + 'id': f'{function.id}.{action["id"]}', + 'name': action.get('name', f'{function.name} ({action["id"]})'), + 'description': function.meta.description, + 'icon': action.get( + 'icon_url', + function.meta.manifest.get('icon_url', None) + or getattr(module, 'icon_url', None) + or getattr(module, 'icon', None), ), } for action in actions @@ -265,12 +247,12 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) else: return [ { - "id": function.id, - "name": function.name, - "description": function.meta.description, - "icon": function.meta.manifest.get("icon_url", None) - or getattr(module, "icon_url", None) - or getattr(module, "icon", None), + 'id': function.id, + 'name': function.name, + 'description': function.meta.description, + 'icon': function.meta.manifest.get('icon_url', None) + or getattr(module, 'icon_url', None) + or getattr(module, 'icon', None), } ] @@ -278,27 +260,25 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) def get_filter_items_from_module(function, module): return [ { - "id": function.id, - "name": function.name, - "description": function.meta.description, - "icon": function.meta.manifest.get("icon_url", None) - or getattr(module, "icon_url", None) - or getattr(module, "icon", None), - "has_user_valves": hasattr(module, "UserValves"), + 'id': function.id, + 'name': function.name, + 'description': function.meta.description, + 'icon': function.meta.manifest.get('icon_url', None) + or getattr(module, 'icon_url', None) + or getattr(module, 'icon', None), + 'has_user_valves': hasattr(module, 'UserValves'), } ] # Batch-prefetch all needed function records to avoid N+1 queries all_function_ids = set() for model in models: - all_function_ids.update(model.get("action_ids", [])) - all_function_ids.update(model.get("filter_ids", [])) + all_function_ids.update(model.get('action_ids', [])) + all_function_ids.update(model.get('filter_ids', [])) all_function_ids.update(global_action_ids) all_function_ids.update(global_filter_ids) - functions_by_id = { - f.id: f for f in Functions.get_functions_by_ids(list(all_function_ids)) - } + functions_by_id = {f.id: f for f in Functions.get_functions_by_ids(list(all_function_ids))} # Pre-warm the function module cache once per unique function ID. # This ensures each function's DB freshness check runs exactly once, @@ -307,28 +287,26 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) try: get_function_module_from_cache(request, function_id) except Exception as e: - log.info(f"Failed to load function module for {function_id}: {e}") + log.info(f'Failed to load function module for {function_id}: {e}') # Apply global model defaults to all models # Per-model overrides take precedence over global defaults - default_metadata = ( - getattr(request.app.state.config, "DEFAULT_MODEL_METADATA", None) or {} - ) + default_metadata = getattr(request.app.state.config, 'DEFAULT_MODEL_METADATA', None) or {} if default_metadata: for model in models: - info = model.get("info") + info = model.get('info') if info is None: - model["info"] = {"meta": copy.deepcopy(default_metadata)} + model['info'] = {'meta': copy.deepcopy(default_metadata)} continue - meta = info.setdefault("meta", {}) + meta = info.setdefault('meta', {}) for key, value in default_metadata.items(): - if key == "capabilities": + if key == 'capabilities': # Merge capabilities: defaults as base, per-model overrides win - existing = meta.get("capabilities") or {} - meta["capabilities"] = {**value, **existing} + existing = meta.get('capabilities') or {} + meta['capabilities'] = {**value, **existing} elif meta.get(key) is None: meta[key] = copy.deepcopy(value) @@ -339,10 +317,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) def get_action_priority(action_id): try: function_module = request.app.state.FUNCTIONS.get(action_id) - if function_module and hasattr(function_module, "Valves"): + if function_module and hasattr(function_module, 'Valves'): valves_db = all_function_valves.get(action_id) valves = function_module.Valves(**(valves_db if valves_db else {})) - return getattr(valves, "priority", 0) + return getattr(valves, 'priority', 0) except Exception: pass return 0 @@ -350,51 +328,47 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) for model in models: action_ids = [ action_id - for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + for action_id in list(set(model.pop('action_ids', []) + global_action_ids)) if action_id in enabled_action_ids ] action_ids.sort(key=lambda aid: (get_action_priority(aid), aid)) filter_ids = [ filter_id - for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids)) + for filter_id in list(set(model.pop('filter_ids', []) + global_filter_ids)) if filter_id in enabled_filter_ids ] - model["actions"] = [] + model['actions'] = [] for action_id in action_ids: action_function = functions_by_id.get(action_id) if action_function is None: - log.info(f"Action not found: {action_id}") + log.info(f'Action not found: {action_id}') continue function_module = request.app.state.FUNCTIONS.get(action_id) if function_module is None: - log.info(f"Failed to load action module: {action_id}") + log.info(f'Failed to load action module: {action_id}') continue - model["actions"].extend( - get_action_items_from_module(action_function, function_module) - ) + model['actions'].extend(get_action_items_from_module(action_function, function_module)) - model["filters"] = [] + model['filters'] = [] for filter_id in filter_ids: filter_function = functions_by_id.get(filter_id) if filter_function is None: - log.info(f"Filter not found: {filter_id}") + log.info(f'Filter not found: {filter_id}') continue function_module = request.app.state.FUNCTIONS.get(filter_id) if function_module is None: - log.info(f"Failed to load filter module: {filter_id}") + log.info(f'Failed to load filter module: {filter_id}') continue - if getattr(function_module, "toggle", None): - model["filters"].extend( - get_filter_items_from_module(filter_function, function_module) - ) + if getattr(function_module, 'toggle', None): + model['filters'].extend(get_filter_items_from_module(filter_function, function_module)) - log.debug(f"get_all_models() returned {len(models)} models") + log.debug(f'get_all_models() returned {len(models)} models') - models_dict = {model["id"]: model for model in models} + models_dict = {model['id']: model for model in models} if isinstance(request.app.state.MODELS, RedisDict): request.app.state.MODELS.set(models_dict) else: @@ -404,81 +378,78 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) def check_model_access(user, model, db=None): - if model.get("arena"): - meta = model.get("info", {}).get("meta", {}) - access_grants = meta.get("access_grants", []) + if model.get('arena'): + meta = model.get('info', {}).get('meta', {}) + access_grants = meta.get('access_grants', []) if not has_access( user.id, - permission="read", + permission='read', access_grants=access_grants, db=db, ): - raise Exception("Model not found") + raise Exception('Model not found') else: - model_info = Models.get_model_by_id(model.get("id"), db=db) + model_info = Models.get_model_by_id(model.get('id'), db=db) if not model_info: - raise Exception("Model not found") + raise Exception('Model not found') elif not ( user.id == model_info.user_id or AccessGrants.has_access( user_id=user.id, - resource_type="model", + resource_type='model', resource_id=model_info.id, - permission="read", + permission='read', db=db, ) ): - raise Exception("Model not found") + raise Exception('Model not found') def get_filtered_models(models, user, db=None): # Filter out models that the user does not have access to if ( - user.role == "user" - or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL) + user.role == 'user' or (user.role == 'admin' and not BYPASS_ADMIN_ACCESS_CONTROL) ) and not BYPASS_MODEL_ACCESS_CONTROL: model_infos = {} for model in models: - if model.get("arena"): + if model.get('arena'): continue - info = model.get("info") + info = model.get('info') if info: - model_infos[model["id"]] = info + model_infos[model['id']] = info - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( user_id=user.id, - resource_type="model", + resource_type='model', resource_ids=list(model_infos.keys()), - permission="read", + permission='read', user_group_ids=user_group_ids, db=db, ) filtered_models = [] for model in models: - if model.get("arena"): - meta = model.get("info", {}).get("meta", {}) - access_grants = meta.get("access_grants", []) + if model.get('arena'): + meta = model.get('info', {}).get('meta', {}) + access_grants = meta.get('access_grants', []) if has_access( user.id, - permission="read", + permission='read', access_grants=access_grants, user_group_ids=user_group_ids, ): filtered_models.append(model) continue - model_info = model_infos.get(model["id"]) + model_info = model_infos.get(model['id']) if model_info: if ( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or user.id == model_info.get("user_id") - or model["id"] in accessible_model_ids + (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == model_info.get('user_id') + or model['id'] in accessible_model_ids ): filtered_models.append(model) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 636ad957d7..202dd42d4a 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -89,9 +89,7 @@ from authlib.oauth2.rfc6749.errors import OAuth2Error class OAuthClientMetadata(MCPOAuthClientMetadata): - token_endpoint_auth_method: Literal[ - "none", "client_secret_basic", "client_secret_post" - ] = "client_secret_post" + token_endpoint_auth_method: Literal['none', 'client_secret_basic', 'client_secret_post'] = 'client_secret_post' pass @@ -114,9 +112,7 @@ log = logging.getLogger(__name__) auth_manager_config = AppConfig() auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP -auth_manager_config.OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE = ( - OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE -) +auth_manager_config.OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE = OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT @@ -151,7 +147,7 @@ else: try: FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) except Exception as e: - log.error(f"Error initializing Fernet with provided key: {e}") + log.error(f'Error initializing Fernet with provided key: {e}') raise @@ -162,7 +158,7 @@ def encrypt_data(data) -> str: encrypted = FERNET.encrypt(data_json.encode()).decode() return encrypted except Exception as e: - log.error(f"Error encrypting data: {e}") + log.error(f'Error encrypting data: {e}') raise @@ -172,7 +168,7 @@ def decrypt_data(data: str): decrypted = FERNET.decrypt(data.encode()).decode() return json.loads(decrypted) except Exception as e: - log.error(f"Error decrypting data: {e}") + log.error(f'Error decrypting data: {e}') raise @@ -183,28 +179,28 @@ def _build_oauth_callback_error_message(e: Exception) -> str: """ if isinstance(e, OAuth2Error): parts = [p for p in [e.error, e.description] if p] - detail = " - ".join(parts) + detail = ' - '.join(parts) elif isinstance(e, HTTPException): detail = e.detail if isinstance(e.detail, str) else str(e.detail) elif isinstance(e, aiohttp.ClientResponseError): - detail = f"Upstream provider returned {e.status}: {e.message}" + detail = f'Upstream provider returned {e.status}: {e.message}' elif isinstance(e, aiohttp.ClientError): detail = str(e) elif isinstance(e, KeyError): missing = str(e).strip("'") - if missing.lower() == "state": - detail = "Missing state parameter in callback (session may have expired)" + if missing.lower() == 'state': + detail = 'Missing state parameter in callback (session may have expired)' else: detail = f"Missing expected key '{missing}' in OAuth response" else: detail = str(e) - detail = detail.replace("\n", " ").strip() + detail = detail.replace('\n', ' ').strip() if not detail: detail = e.__class__.__name__ - message = f"OAuth callback failed: {detail}" - return message[:197] + "..." if len(message) > 200 else message + message = f'OAuth callback failed: {detail}' + return message[:197] + '...' if len(message) > 200 else message def is_in_blocked_groups(group_name: str, groups: list) -> bool: @@ -231,10 +227,7 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool: return True # Try as regex pattern first if it contains regex-specific characters - if any( - char in group_pattern - for char in ["^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|"] - ): + if any(char in group_pattern for char in ['^', '$', '[', ']', '(', ')', '{', '}', '+', '\\', '|']): try: # Use the original pattern as-is for regex matching if re.search(group_pattern, group_name): @@ -244,7 +237,7 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool: pass # Shell-style wildcard match (supports * and ?) - if "*" in group_pattern or "?" in group_pattern: + if '*' in group_pattern or '?' in group_pattern: if fnmatch.fnmatch(group_name, group_pattern): return True @@ -253,7 +246,7 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool: def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: parsed = urllib.parse.urlparse(server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" + base_url = f'{parsed.scheme}://{parsed.netloc}' return parsed, base_url @@ -267,20 +260,18 @@ async def get_authorization_server_discovery_urls(server_url: str) -> list[str]: async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( server_url, - json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1}, - headers={"Content-Type": "application/json"}, + json={'jsonrpc': '2.0', 'method': 'initialize', 'params': {}, 'id': 1}, + headers={'Content-Type': 'application/json'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: if response.status == 401: match = re.search( r'resource_metadata="([^"]+)"', - response.headers.get("WWW-Authenticate", ""), + response.headers.get('WWW-Authenticate', ''), ) if match: resource_metadata_url = match.group(1) - log.debug( - f"Found resource_metadata URL: {resource_metadata_url}" - ) + log.debug(f'Found resource_metadata URL: {resource_metadata_url}') # Step 2: Fetch Protected Resource metadata async with session.get( @@ -290,24 +281,20 @@ async def get_authorization_server_discovery_urls(server_url: str) -> list[str]: resource_metadata = await resource_response.json() # Step 3: Extract authorization_servers - servers = resource_metadata.get( - "authorization_servers", [] - ) + servers = resource_metadata.get('authorization_servers', []) if servers: authorization_servers = servers - log.debug( - f"Discovered authorization servers: {servers}" - ) + log.debug(f'Discovered authorization servers: {servers}') except Exception as e: - log.debug(f"MCP Protected Resource discovery failed: {e}") + log.debug(f'MCP Protected Resource discovery failed: {e}') discovery_urls = [] for auth_server in authorization_servers: - auth_server = auth_server.rstrip("/") + auth_server = auth_server.rstrip('/') discovery_urls.extend( [ - f"{auth_server}/.well-known/oauth-authorization-server", - f"{auth_server}/.well-known/openid-configuration", + f'{auth_server}/.well-known/oauth-authorization-server', + f'{auth_server}/.well-known/openid-configuration', ] ) @@ -318,28 +305,24 @@ async def get_discovery_urls(server_url) -> list[str]: urls = await get_authorization_server_discovery_urls(server_url) parsed, base_url = get_parsed_and_base_url(server_url) - if parsed.path and parsed.path != "/": + if parsed.path and parsed.path != '/': # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery - tenant = parsed.path.rstrip("/") + tenant = parsed.path.rstrip('/') urls.extend( [ urllib.parse.urljoin( base_url, - f"/.well-known/oauth-authorization-server{tenant}", - ), - urllib.parse.urljoin( - base_url, f"/.well-known/openid-configuration{tenant}" - ), - urllib.parse.urljoin( - base_url, f"{tenant}/.well-known/openid-configuration" + f'/.well-known/oauth-authorization-server{tenant}', ), + urllib.parse.urljoin(base_url, f'/.well-known/openid-configuration{tenant}'), + urllib.parse.urljoin(base_url, f'{tenant}/.well-known/openid-configuration'), ] ) urls.extend( [ - urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), - urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), + urllib.parse.urljoin(base_url, '/.well-known/oauth-authorization-server'), + urllib.parse.urljoin(base_url, '/.well-known/openid-configuration'), ] ) @@ -358,24 +341,20 @@ async def get_oauth_client_info_with_dynamic_client_registration( oauth_server_metadata = None oauth_server_metadata_url = None - redirect_base_url = ( - str(request.app.state.config.WEBUI_URL or request.base_url) - ).rstrip("/") + redirect_base_url = (str(request.app.state.config.WEBUI_URL or request.base_url)).rstrip('/') oauth_client_metadata = OAuthClientMetadata( - client_name="Open WebUI", - redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], + client_name='Open WebUI', + redirect_uris=[f'{redirect_base_url}/oauth/clients/{client_id}/callback'], + grant_types=['authorization_code', 'refresh_token'], + response_types=['code'], ) # Attempt to fetch OAuth server metadata to get registration endpoint & scopes discovery_urls = await get_discovery_urls(oauth_server_url) for url in discovery_urls: async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get( - url, ssl=AIOHTTP_CLIENT_SESSION_SSL - ) as oauth_server_metadata_response: + async with session.get(url, ssl=AIOHTTP_CLIENT_SESSION_SSL) as oauth_server_metadata_response: if oauth_server_metadata_response.status == 200: try: oauth_server_metadata = OAuthMetadata.model_validate( @@ -386,9 +365,7 @@ async def get_oauth_client_info_with_dynamic_client_registration( oauth_client_metadata.scope is None and oauth_server_metadata.scopes_supported is not None ): - oauth_client_metadata.scope = " ".join( - oauth_server_metadata.scopes_supported - ) + oauth_client_metadata.scope = ' '.join(oauth_server_metadata.scopes_supported) if ( oauth_server_metadata.token_endpoint_auth_methods_supported @@ -396,13 +373,13 @@ async def get_oauth_client_info_with_dynamic_client_registration( not in oauth_server_metadata.token_endpoint_auth_methods_supported ): # Pick the first supported method from the server - oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[ - 0 - ] + oauth_client_metadata.token_endpoint_auth_method = ( + oauth_server_metadata.token_endpoint_auth_methods_supported[0] + ) break except Exception as e: - log.error(f"Error parsing OAuth metadata from {url}: {e}") + log.error(f'Error parsing OAuth metadata from {url}: {e}') continue registration_url = None @@ -410,11 +387,11 @@ async def get_oauth_client_info_with_dynamic_client_registration( registration_url = str(oauth_server_metadata.registration_endpoint) else: _, base_url = get_parsed_and_base_url(oauth_server_url) - registration_url = urllib.parse.urljoin(base_url, "/register") + registration_url = urllib.parse.urljoin(base_url, '/register') registration_data = oauth_client_metadata.model_dump( exclude_none=True, - mode="json", + mode='json', by_alias=True, ) @@ -424,25 +401,22 @@ async def get_oauth_client_info_with_dynamic_client_registration( registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL ) as oauth_client_registration_response: try: - registration_response_json = ( - await oauth_client_registration_response.json() - ) + registration_response_json = await oauth_client_registration_response.json() # The mcp package requires optional unset values to be None. If an empty string is passed, it gets validated and fails. # This replaces all empty strings with None. registration_response_json = { - k: (None if v == "" else v) - for k, v in registration_response_json.items() + k: (None if v == '' else v) for k, v in registration_response_json.items() } oauth_client_info = OAuthClientInformationFull.model_validate( { **registration_response_json, - **{"issuer": oauth_server_metadata_url}, - **{"server_metadata": oauth_server_metadata}, + **{'issuer': oauth_server_metadata_url}, + **{'server_metadata': oauth_server_metadata}, } ) log.info( - f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}" + f'Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}' ) return oauth_client_info except Exception as e: @@ -450,20 +424,20 @@ async def get_oauth_client_info_with_dynamic_client_registration( try: error_text = await oauth_client_registration_response.text() log.error( - f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}" + f'Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}' ) except Exception as e: pass - log.error(f"Error parsing client registration response: {e}") + log.error(f'Error parsing client registration response: {e}') raise Exception( - f"Dynamic client registration failed: {error_text}" + f'Dynamic client registration failed: {error_text}' if error_text - else "Error parsing client registration response" + else 'Error parsing client registration response' ) - raise Exception("Dynamic client registration failed") + raise Exception('Dynamic client registration failed') except Exception as e: - log.error(f"Exception during dynamic client registration: {e}") + log.error(f'Exception during dynamic client registration: {e}') raise e @@ -475,45 +449,33 @@ class OAuthClientManager: def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): kwargs = { - "name": client_id, - "client_id": oauth_client_info.client_id, - "client_secret": oauth_client_info.client_secret, - "client_kwargs": { + 'name': client_id, + 'client_id': oauth_client_info.client_id, + 'client_secret': oauth_client_info.client_secret, + 'client_kwargs': { + **({'scope': oauth_client_info.scope} if oauth_client_info.scope else {}), **( - {"scope": oauth_client_info.scope} - if oauth_client_info.scope - else {} - ), - **( - { - "token_endpoint_auth_method": oauth_client_info.token_endpoint_auth_method - } + {'token_endpoint_auth_method': oauth_client_info.token_endpoint_auth_method} if oauth_client_info.token_endpoint_auth_method else {} ), }, - "server_metadata_url": ( - oauth_client_info.issuer if oauth_client_info.issuer else None - ), + 'server_metadata_url': (oauth_client_info.issuer if oauth_client_info.issuer else None), } - if ( - oauth_client_info.server_metadata - and oauth_client_info.server_metadata.code_challenge_methods_supported - ): + if oauth_client_info.server_metadata and oauth_client_info.server_metadata.code_challenge_methods_supported: if ( isinstance( oauth_client_info.server_metadata.code_challenge_methods_supported, list, ) - and "S256" - in oauth_client_info.server_metadata.code_challenge_methods_supported + and 'S256' in oauth_client_info.server_metadata.code_challenge_methods_supported ): - kwargs["code_challenge_method"] = "S256" + kwargs['code_challenge_method'] = 'S256' self.clients[client_id] = { - "client": self.oauth.register(**kwargs), - "client_info": oauth_client_info, + 'client': self.oauth.register(**kwargs), + 'client_info': oauth_client_info, } return self.clients[client_id] @@ -523,40 +485,36 @@ class OAuthClientManager: config if it hasn't been registered on this node yet. """ if client_id in self.clients: - return self.clients[client_id]["client"] + return self.clients[client_id]['client'] try: - connections = getattr(self.app.state.config, "TOOL_SERVER_CONNECTIONS", []) + connections = getattr(self.app.state.config, 'TOOL_SERVER_CONNECTIONS', []) except Exception: connections = [] for connection in connections or []: - if connection.get("type", "openapi") != "mcp": + if connection.get('type', 'openapi') != 'mcp': continue - if connection.get("auth_type", "none") != "oauth_2.1": + if connection.get('auth_type', 'none') != 'oauth_2.1': continue - server_id = connection.get("info", {}).get("id") + server_id = connection.get('info', {}).get('id') if not server_id: continue - expected_client_id = f"mcp:{server_id}" + expected_client_id = f'mcp:{server_id}' if client_id != expected_client_id: continue - oauth_client_info = connection.get("info", {}).get("oauth_client_info", "") + oauth_client_info = connection.get('info', {}).get('oauth_client_info', '') if not oauth_client_info: continue try: oauth_client_info = decrypt_data(oauth_client_info) - return self.add_client( - expected_client_id, OAuthClientInformationFull(**oauth_client_info) - )["client"] + return self.add_client(expected_client_id, OAuthClientInformationFull(**oauth_client_info))['client'] except Exception as e: - log.error( - f"Failed to lazily add OAuth client {expected_client_id} from config: {e}" - ) + log.error(f'Failed to lazily add OAuth client {expected_client_id} from config: {e}') continue return None @@ -564,24 +522,22 @@ class OAuthClientManager: def remove_client(self, client_id): if client_id in self.clients: del self.clients[client_id] - log.info(f"Removed OAuth client {client_id}") + log.info(f'Removed OAuth client {client_id}') - if hasattr(self.oauth, "_clients"): + if hasattr(self.oauth, '_clients'): if client_id in self.oauth._clients: self.oauth._clients.pop(client_id, None) - if hasattr(self.oauth, "_registry"): + if hasattr(self.oauth, '_registry'): if client_id in self.oauth._registry: self.oauth._registry.pop(client_id, None) return True - async def _preflight_authorization_url( - self, client, client_info: OAuthClientInformationFull - ) -> bool: + async def _preflight_authorization_url(self, client, client_info: OAuthClientInformationFull) -> bool: # TODO: Replace this logic with a more robust OAuth client registration validation # Only perform preflight checks for Starlette OAuth clients - if not hasattr(client, "create_authorization_url"): + if not hasattr(client, 'create_authorization_url'): return True redirect_uri = None @@ -590,13 +546,13 @@ class OAuthClientManager: try: auth_data = await client.create_authorization_url(redirect_uri=redirect_uri) - authorization_url = auth_data.get("url") + authorization_url = auth_data.get('url') if not authorization_url: return True except Exception as e: log.debug( - f"Skipping OAuth preflight for client {client_info.client_id}: {e}", + f'Skipping OAuth preflight for client {client_info.client_id}: {e}', ) return True @@ -612,34 +568,29 @@ class OAuthClientManager: response_text = await resp.text() error = None - error_description = "" + error_description = '' - content_type = resp.headers.get("content-type", "") - if "application/json" in content_type: + content_type = resp.headers.get('content-type', '') + if 'application/json' in content_type: try: payload = json.loads(response_text) - error = payload.get("error") - error_description = payload.get("error_description", "") + error = payload.get('error') + error_description = payload.get('error_description', '') except Exception: pass else: error_description = response_text - error_message = f"{error or ''} {error_description or ''}".lower() + error_message = f'{error or ""} {error_description or ""}'.lower() - if any( - keyword in error_message - for keyword in ("invalid_client", "invalid client", "client id") - ): + if any(keyword in error_message for keyword in ('invalid_client', 'invalid client', 'client id')): log.warning( - f"OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}" + f'OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}' ) return False except Exception as e: - log.debug( - f"Skipping OAuth preflight network check for client {client_info.client_id}: {e}" - ) + log.debug(f'Skipping OAuth preflight network check for client {client_info.client_id}: {e}') return True @@ -648,29 +599,23 @@ class OAuthClientManager: self.ensure_client_from_config(client_id) client = self.clients.get(client_id) - return client["client"] if client else None + return client['client'] if client else None def get_client_info(self, client_id): if client_id not in self.clients: self.ensure_client_from_config(client_id) client = self.clients.get(client_id) - return client["client_info"] if client else None + return client['client_info'] if client else None def get_server_metadata_url(self, client_id): client = self.get_client(client_id) if not client: return None - return ( - client._server_metadata_url - if hasattr(client, "_server_metadata_url") - else None - ) + return client._server_metadata_url if hasattr(client, '_server_metadata_url') else None - async def get_oauth_token( - self, user_id: str, client_id: str, force_refresh: bool = False - ): + async def get_oauth_token(self, user_id: str, client_id: str, force_refresh: bool = False): """ Get a valid OAuth token for the user, automatically refreshing if needed. @@ -684,34 +629,26 @@ class OAuthClientManager: """ try: # Get the OAuth session - session = OAuthSessions.get_session_by_provider_and_user_id( - client_id, user_id - ) + session = OAuthSessions.get_session_by_provider_and_user_id(client_id, user_id) if not session: - log.warning( - f"No OAuth session found for user {user_id}, client_id {client_id}" - ) + log.warning(f'No OAuth session found for user {user_id}, client_id {client_id}') return None - if force_refresh or datetime.now() + timedelta( - minutes=5 - ) >= datetime.fromtimestamp(session.expires_at): - log.debug( - f"Token refresh needed for user {user_id}, client_id {session.provider}" - ) + if force_refresh or datetime.now() + timedelta(minutes=5) >= datetime.fromtimestamp(session.expires_at): + log.debug(f'Token refresh needed for user {user_id}, client_id {session.provider}') refreshed_token = await self._refresh_token(session) if refreshed_token: return refreshed_token else: log.warning( - f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}" + f'Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}' ) OAuthSessions.delete_session_by_id(session.id) return None return session.token except Exception as e: - log.error(f"Error getting OAuth token for user {user_id}: {e}") + log.error(f'Error getting OAuth token for user {user_id}: {e}') return None async def _refresh_token(self, session) -> dict: @@ -730,17 +667,15 @@ class OAuthClientManager: if refreshed_token: # Update the session with new token data - session = OAuthSessions.update_session_by_id( - session.id, refreshed_token - ) - log.info(f"Successfully refreshed token for session {session.id}") + session = OAuthSessions.update_session_by_id(session.id, refreshed_token) + log.info(f'Successfully refreshed token for session {session.id}') return session.token else: - log.error(f"Failed to refresh token for session {session.id}") + log.error(f'Failed to refresh token for session {session.id}') return None except Exception as e: - log.error(f"Error refreshing token for session {session.id}: {e}") + log.error(f'Error refreshing token for session {session.id}: {e}') return None async def _perform_token_refresh(self, session) -> dict: @@ -756,92 +691,78 @@ class OAuthClientManager: client_id = session.provider token_data = session.token - if not token_data.get("refresh_token"): - log.warning(f"No refresh token available for session {session.id}") + if not token_data.get('refresh_token'): + log.warning(f'No refresh token available for session {session.id}') return None try: client = self.get_client(client_id) if not client: - log.error(f"No OAuth client found for provider {client_id}") + log.error(f'No OAuth client found for provider {client_id}') return None token_endpoint = None async with aiohttp.ClientSession(trust_env=True) as session_http: - async with session_http.get( - self.get_server_metadata_url(client_id) - ) as r: + async with session_http.get(self.get_server_metadata_url(client_id)) as r: if r.status == 200: openid_data = await r.json() - token_endpoint = openid_data.get("token_endpoint") + token_endpoint = openid_data.get('token_endpoint') else: - log.error( - f"Failed to fetch OpenID configuration for client_id {client_id}" - ) + log.error(f'Failed to fetch OpenID configuration for client_id {client_id}') if not token_endpoint: - log.error(f"No token endpoint found for client_id {client_id}") + log.error(f'No token endpoint found for client_id {client_id}') return None # Prepare refresh request refresh_data = { - "grant_type": "refresh_token", - "refresh_token": token_data["refresh_token"], - "client_id": client.client_id, + 'grant_type': 'refresh_token', + 'refresh_token': token_data['refresh_token'], + 'client_id': client.client_id, } - if hasattr(client, "client_secret") and client.client_secret: - refresh_data["client_secret"] = client.client_secret + if hasattr(client, 'client_secret') and client.client_secret: + refresh_data['client_secret'] = client.client_secret # Add scope if available in client kwargs (some providers require it on refresh) if ( - hasattr(client, "client_kwargs") - and client.client_kwargs.get("scope") - and getattr( - self.app.state.config, "OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE", False - ) + hasattr(client, 'client_kwargs') + and client.client_kwargs.get('scope') + and getattr(self.app.state.config, 'OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE', False) ): - refresh_data["scope"] = client.client_kwargs["scope"] + refresh_data['scope'] = client.client_kwargs['scope'] # Make refresh request async with aiohttp.ClientSession(trust_env=True) as session_http: async with session_http.post( token_endpoint, data=refresh_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status == 200: new_token_data = await r.json() # Merge with existing token data (preserve refresh_token if not provided) - if "refresh_token" not in new_token_data: - new_token_data["refresh_token"] = token_data[ - "refresh_token" - ] + if 'refresh_token' not in new_token_data: + new_token_data['refresh_token'] = token_data['refresh_token'] # Add timestamp for tracking - new_token_data["issued_at"] = datetime.now().timestamp() + new_token_data['issued_at'] = datetime.now().timestamp() # Calculate expires_at if we have expires_in - if ( - "expires_in" in new_token_data - and "expires_at" not in new_token_data - ): - new_token_data["expires_at"] = int( - datetime.now().timestamp() - + new_token_data["expires_in"] + if 'expires_in' in new_token_data and 'expires_at' not in new_token_data: + new_token_data['expires_at'] = int( + datetime.now().timestamp() + new_token_data['expires_in'] ) - log.debug(f"Token refresh successful for client_id {client_id}") + log.debug(f'Token refresh successful for client_id {client_id}') return new_token_data else: error_text = await r.text() - log.error( - f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}" - ) + log.error(f'Token refresh failed for client_id {client_id}: {r.status} - {error_text}') return None except Exception as e: - log.error(f"Exception during token refresh for client_id {client_id}: {e}") + log.error(f'Exception during token refresh for client_id {client_id}: {e}') return None async def handle_authorize(self, request, client_id: str) -> RedirectResponse: @@ -855,9 +776,7 @@ class OAuthClientManager: if client_info is None: raise HTTPException(404) - redirect_uri = ( - client_info.redirect_uris[0] if client_info.redirect_uris else None - ) + redirect_uri = client_info.redirect_uris[0] if client_info.redirect_uris else None redirect_uri_str = str(redirect_uri) if redirect_uri else None return await client.authorize_redirect(request, redirect_uri_str) @@ -878,24 +797,20 @@ class OAuthClientManager: # Validate that we received a proper token response # If token exchange failed (e.g., 401), we may get an error response instead - if token and not token.get("access_token"): - error_desc = token.get( - "error_description", token.get("error", "Unknown error") - ) - error_message = f"Token exchange failed: {error_desc}" - log.error(f"Invalid token response for client_id {client_id}: {token}") + if token and not token.get('access_token'): + error_desc = token.get('error_description', token.get('error', 'Unknown error')) + error_message = f'Token exchange failed: {error_desc}' + log.error(f'Invalid token response for client_id {client_id}: {token}') token = None if token: try: # Add timestamp for tracking - token["issued_at"] = datetime.now().timestamp() + token['issued_at'] = datetime.now().timestamp() # Calculate expires_at if we have expires_in - if "expires_in" in token and "expires_at" not in token: - token["expires_at"] = ( - datetime.now().timestamp() + token["expires_in"] - ) + if 'expires_in' in token and 'expires_at' not in token: + token['expires_at'] = datetime.now().timestamp() + token['expires_in'] # Clean up any existing sessions for this user/client_id first sessions = OAuthSessions.get_sessions_by_user_id(user_id) @@ -908,35 +823,29 @@ class OAuthClientManager: provider=client_id, token=token, ) - log.info( - f"Stored OAuth session server-side for user {user_id}, client_id {client_id}" - ) + log.info(f'Stored OAuth session server-side for user {user_id}, client_id {client_id}') except Exception as e: - error_message = "Failed to store OAuth session server-side" - log.error(f"Failed to store OAuth session server-side: {e}") + error_message = 'Failed to store OAuth session server-side' + log.error(f'Failed to store OAuth session server-side: {e}') else: if not error_message: - error_message = "Failed to obtain OAuth token" + error_message = 'Failed to obtain OAuth token' log.warning(error_message) except Exception as e: error_message = _build_oauth_callback_error_message(e) log.warning( - "OAuth callback error for user_id=%s client_id=%s: %s", + 'OAuth callback error for user_id=%s client_id=%s: %s', user_id, client_id, error_message, exc_info=True, ) - redirect_url = ( - str(request.app.state.config.WEBUI_URL or request.base_url) - ).rstrip("/") + redirect_url = (str(request.app.state.config.WEBUI_URL or request.base_url)).rstrip('/') if error_message: log.debug(error_message) - redirect_url = ( - f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}" - ) + redirect_url = f'{redirect_url}/?error={urllib.parse.quote_plus(error_message)}' return RedirectResponse(url=redirect_url, headers=response.headers) response = RedirectResponse(url=redirect_url, headers=response.headers) @@ -951,11 +860,11 @@ class OAuthManager: self._clients = {} for name, provider_config in OAUTH_PROVIDERS.items(): - if "register" not in provider_config: - log.error(f"OAuth provider {name} missing register function") + if 'register' not in provider_config: + log.error(f'OAuth provider {name} missing register function') continue - client = provider_config["register"](self.oauth) + client = provider_config['register'](self.oauth) self._clients[name] = client def get_client(self, provider_name): @@ -966,16 +875,10 @@ class OAuthManager: def get_server_metadata_url(self, provider_name): if provider_name in self._clients: client = self._clients[provider_name] - return ( - client._server_metadata_url - if hasattr(client, "_server_metadata_url") - else None - ) + return client._server_metadata_url if hasattr(client, '_server_metadata_url') else None return None - async def get_oauth_token( - self, user_id: str, session_id: str, force_refresh: bool = False - ): + async def get_oauth_token(self, user_id: str, session_id: str, force_refresh: bool = False): """ Get a valid OAuth token for the user, automatically refreshing if needed. @@ -991,23 +894,17 @@ class OAuthManager: # Get the OAuth session session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id) if not session: - log.warning( - f"No OAuth session found for user {user_id}, session {session_id}" - ) + log.warning(f'No OAuth session found for user {user_id}, session {session_id}') return None - if force_refresh or datetime.now() + timedelta( - minutes=5 - ) >= datetime.fromtimestamp(session.expires_at): - log.debug( - f"Token refresh needed for user {user_id}, provider {session.provider}" - ) + if force_refresh or datetime.now() + timedelta(minutes=5) >= datetime.fromtimestamp(session.expires_at): + log.debug(f'Token refresh needed for user {user_id}, provider {session.provider}') refreshed_token = await self._refresh_token(session) if refreshed_token: return refreshed_token else: log.warning( - f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}" + f'Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}' ) OAuthSessions.delete_session_by_id(session.id) @@ -1015,7 +912,7 @@ class OAuthManager: return session.token except Exception as e: - log.error(f"Error getting OAuth token for user {user_id}: {e}") + log.error(f'Error getting OAuth token for user {user_id}: {e}') return None async def _refresh_token(self, session) -> dict: @@ -1034,17 +931,15 @@ class OAuthManager: if refreshed_token: # Update the session with new token data - session = OAuthSessions.update_session_by_id( - session.id, refreshed_token - ) - log.info(f"Successfully refreshed token for session {session.id}") + session = OAuthSessions.update_session_by_id(session.id, refreshed_token) + log.info(f'Successfully refreshed token for session {session.id}') return session.token else: - log.error(f"Failed to refresh token for session {session.id}") + log.error(f'Failed to refresh token for session {session.id}') return None except Exception as e: - log.error(f"Error refreshing token for session {session.id}: {e}") + log.error(f'Error refreshing token for session {session.id}: {e}') return None async def _perform_token_refresh(self, session) -> dict: @@ -1060,14 +955,14 @@ class OAuthManager: provider = session.provider token_data = session.token - if not token_data.get("refresh_token"): - log.warning(f"No refresh token available for session {session.id}") + if not token_data.get('refresh_token'): + log.warning(f'No refresh token available for session {session.id}') return None try: client = self.get_client(provider) if not client: - log.error(f"No OAuth client found for provider {provider}") + log.error(f'No OAuth client found for provider {provider}') return None server_metadata_url = self.get_server_metadata_url(provider) @@ -1076,89 +971,79 @@ class OAuthManager: async with session_http.get(server_metadata_url) as r: if r.status == 200: openid_data = await r.json() - token_endpoint = openid_data.get("token_endpoint") + token_endpoint = openid_data.get('token_endpoint') else: - log.error( - f"Failed to fetch OpenID configuration for provider {provider}" - ) + log.error(f'Failed to fetch OpenID configuration for provider {provider}') if not token_endpoint: - log.error(f"No token endpoint found for provider {provider}") + log.error(f'No token endpoint found for provider {provider}') return None # Prepare refresh request refresh_data = { - "grant_type": "refresh_token", - "refresh_token": token_data["refresh_token"], - "client_id": client.client_id, + 'grant_type': 'refresh_token', + 'refresh_token': token_data['refresh_token'], + 'client_id': client.client_id, } # Add client_secret if available (some providers require it) - if hasattr(client, "client_secret") and client.client_secret: - refresh_data["client_secret"] = client.client_secret + if hasattr(client, 'client_secret') and client.client_secret: + refresh_data['client_secret'] = client.client_secret # Add scope if available in client kwargs (some providers require it on refresh) if ( - hasattr(client, "client_kwargs") - and client.client_kwargs.get("scope") + hasattr(client, 'client_kwargs') + and client.client_kwargs.get('scope') and auth_manager_config.OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE ): - refresh_data["scope"] = client.client_kwargs["scope"] + refresh_data['scope'] = client.client_kwargs['scope'] # Make refresh request async with aiohttp.ClientSession(trust_env=True) as session_http: async with session_http.post( token_endpoint, data=refresh_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status == 200: new_token_data = await r.json() # Merge with existing token data (preserve refresh_token if not provided) - if "refresh_token" not in new_token_data: - new_token_data["refresh_token"] = token_data[ - "refresh_token" - ] + if 'refresh_token' not in new_token_data: + new_token_data['refresh_token'] = token_data['refresh_token'] # Add timestamp for tracking - new_token_data["issued_at"] = datetime.now().timestamp() + new_token_data['issued_at'] = datetime.now().timestamp() # Calculate expires_at if we have expires_in - if ( - "expires_in" in new_token_data - and "expires_at" not in new_token_data - ): - new_token_data["expires_at"] = int( - datetime.now().timestamp() - + new_token_data["expires_in"] + if 'expires_in' in new_token_data and 'expires_at' not in new_token_data: + new_token_data['expires_at'] = int( + datetime.now().timestamp() + new_token_data['expires_in'] ) - log.debug(f"Token refresh successful for provider {provider}") + log.debug(f'Token refresh successful for provider {provider}') return new_token_data else: error_text = await r.text() - log.error( - f"Token refresh failed for provider {provider}: {r.status} - {error_text}" - ) + log.error(f'Token refresh failed for provider {provider}: {r.status} - {error_text}') return None except Exception as e: - log.error(f"Exception during token refresh for provider {provider}: {e}") + log.error(f'Exception during token refresh for provider {provider}: {e}') return None def get_user_role(self, user, user_data): user_count = Users.get_num_users() if user and user_count == 1: # If the user is the only user, assign the role "admin" - actually repairs role for single user on login - log.debug("Assigning the only user the admin role") - return "admin" + log.debug('Assigning the only user the admin role') + return 'admin' if not user and user_count == 0: # If there are no users, assign the role "admin", as the first user will be an admin - log.debug("Assigning the first user the admin role") - return "admin" + log.debug('Assigning the first user the admin role') + return 'admin' if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT: - log.debug("Running OAUTH Role management") + log.debug('Running OAUTH Role management') oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES @@ -1169,7 +1054,7 @@ class OAuthManager: # Next block extracts the roles from the user data, accepting nested claims of any depth if oauth_claim and oauth_allowed_roles and oauth_admin_roles: claim_data = user_data - nested_claims = oauth_claim.split(".") + nested_claims = oauth_claim.split('.') for nested_claim in nested_claims: claim_data = claim_data.get(nested_claim, {}) @@ -1190,10 +1075,10 @@ class OAuthManager: elif isinstance(claim_data, int): oauth_roles = [str(claim_data)] - log.debug(f"Oauth Roles claim: {oauth_claim}") - log.debug(f"User roles from oauth: {oauth_roles}") - log.debug(f"Accepted user roles: {oauth_allowed_roles}") - log.debug(f"Accepted admin roles: {oauth_admin_roles}") + log.debug(f'Oauth Roles claim: {oauth_claim}') + log.debug(f'User roles from oauth: {oauth_roles}') + log.debug(f'Accepted user roles: {oauth_allowed_roles}') + log.debug(f'Accepted admin roles: {oauth_admin_roles}') # If any roles are found, check if they match the allowed or admin roles if oauth_roles: @@ -1201,14 +1086,14 @@ class OAuthManager: for allowed_role in oauth_allowed_roles: # If the user has any of the allowed roles, assign the role "user" if allowed_role in oauth_roles: - log.debug("Assigned user the user role") - role = "user" + log.debug('Assigned user the user role') + role = 'user' break for admin_role in oauth_admin_roles: # If the user has any of the admin roles, assign the role "admin" if admin_role in oauth_roles: - log.debug("Assigned user the admin role") - role = "admin" + log.debug('Assigned user the admin role') + role = 'admin' break else: if not user: @@ -1221,20 +1106,20 @@ class OAuthManager: return role def update_user_groups(self, user, user_data, default_permissions, db=None): - log.debug("Running OAUTH Group management") + log.debug('Running OAUTH Group management') oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM try: blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS) except Exception as e: - log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}") + log.exception(f'Error loading OAUTH_BLOCKED_GROUPS: {e}') blocked_groups = [] user_oauth_groups = [] # Nested claim search for groups claim if oauth_claim: claim_data = user_data - nested_claims = oauth_claim.split(".") + nested_claims = oauth_claim.split('.') for nested_claim in nested_claims: claim_data = claim_data.get(nested_claim, {}) @@ -1249,41 +1134,31 @@ class OAuthManager: else: user_oauth_groups = [] - user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id( - user.id, db=db - ) + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id, db=db) all_available_groups: list[GroupModel] = Groups.get_all_groups(db=db) # Create groups if they don't exist and creation is enabled if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: - log.debug("Checking for missing groups to create...") + log.debug('Checking for missing groups to create...') all_group_names = {g.name for g in all_available_groups} groups_created = False # Determine creator ID: Prefer admin, fallback to current user if no admin exists admin_user = Users.get_super_admin_user() creator_id = admin_user.id if admin_user else user.id - log.debug(f"Using creator ID {creator_id} for potential group creation.") + log.debug(f'Using creator ID {creator_id} for potential group creation.') for group_name in user_oauth_groups: if group_name not in all_group_names: - log.info( - f"Group '{group_name}' not found via OAuth claim. Creating group..." - ) + log.info(f"Group '{group_name}' not found via OAuth claim. Creating group...") try: new_group_form = GroupForm( name=group_name, description=f"Group '{group_name}' created automatically via OAuth.", permissions=default_permissions, # Use default permissions from function args - data={ - "config": { - "share": auth_manager_config.OAUTH_GROUP_DEFAULT_SHARE - } - }, + data={'config': {'share': auth_manager_config.OAUTH_GROUP_DEFAULT_SHARE}}, ) # Use determined creator ID (admin or fallback to current user) - created_group = Groups.insert_new_group( - creator_id, new_group_form, db=db - ) + created_group = Groups.insert_new_group(creator_id, new_group_form, db=db) if created_group: log.info( f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}" @@ -1292,23 +1167,19 @@ class OAuthManager: # Add to local set to prevent duplicate creation attempts in this run all_group_names.add(group_name) else: - log.error( - f"Failed to create group '{group_name}' via OAuth." - ) + log.error(f"Failed to create group '{group_name}' via OAuth.") except Exception as e: log.error(f"Error creating group '{group_name}' via OAuth: {e}") # Refresh the list of all available groups if any were created if groups_created: all_available_groups = Groups.get_all_groups(db=db) - log.debug("Refreshed list of all available groups after creation.") + log.debug('Refreshed list of all available groups after creation.') - log.debug(f"Oauth Groups claim: {oauth_claim}") - log.debug(f"User oauth groups: {user_oauth_groups}") + log.debug(f'Oauth Groups claim: {oauth_claim}') + log.debug(f'User oauth groups: {user_oauth_groups}') log.debug(f"User's current groups: {[g.name for g in user_current_groups]}") - log.debug( - f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}" - ) + log.debug(f'All groups available in OpenWebUI: {[g.name for g in all_available_groups]}') # Remove groups that user is no longer a part of for group_model in user_current_groups: @@ -1318,9 +1189,7 @@ class OAuthManager: and not is_in_blocked_groups(group_model.name, blocked_groups) ): # Remove group from user - log.debug( - f"Removing user from group {group_model.name} as it is no longer in their oauth groups" - ) + log.debug(f'Removing user from group {group_model.name} as it is no longer in their oauth groups') Groups.remove_users_from_group(group_model.id, [user.id], db=db) # In case a group is created, but perms are never assigned to the group by hitting "save" @@ -1348,9 +1217,7 @@ class OAuthManager: and not is_in_blocked_groups(group_model.name, blocked_groups) ): # Add user to group - log.debug( - f"Adding user to group {group_model.name} as it was found in their oauth groups" - ) + log.debug(f'Adding user to group {group_model.name} as it was found in their oauth groups') Groups.add_users_to_group(group_model.id, [user.id], db=db) @@ -1370,9 +1237,7 @@ class OAuthManager: db=db, ) - async def _process_picture_url( - self, picture_url: str, access_token: str = None - ) -> str: + async def _process_picture_url(self, picture_url: str, access_token: str = None) -> str: """Process a picture URL and return a base64 encoded data URL. Args: @@ -1383,44 +1248,36 @@ class OAuthManager: A data URL containing the base64 encoded picture, or "/user.png" if processing fails """ if not picture_url: - return "/user.png" + return '/user.png' try: get_kwargs = {} if access_token: - get_kwargs["headers"] = { - "Authorization": f"Bearer {access_token}", + get_kwargs['headers'] = { + 'Authorization': f'Bearer {access_token}', } async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get( - picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL - ) as resp: + async with session.get(picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL) as resp: if resp.ok: picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) + base64_encoded_picture = base64.b64encode(picture).decode('utf-8') guessed_mime_type = mimetypes.guess_type(picture_url)[0] if guessed_mime_type is None: - guessed_mime_type = "image/jpeg" - return ( - f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - ) + guessed_mime_type = 'image/jpeg' + return f'data:{guessed_mime_type};base64,{base64_encoded_picture}' else: - log.warning( - f"Failed to fetch profile picture from {picture_url}" - ) - return "/user.png" + log.warning(f'Failed to fetch profile picture from {picture_url}') + return '/user.png' except Exception as e: log.error(f"Error processing profile picture '{picture_url}': {e}") - return "/user.png" + return '/user.png' async def handle_login(self, request, provider): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_login_callback", provider=provider + redirect_uri = OAUTH_PROVIDERS[provider].get('redirect_uri') or request.url_for( + 'oauth_login_callback', provider=provider ) client = self.get_client(provider) if client is None: @@ -1428,7 +1285,7 @@ class OAuthManager: kwargs = {} if auth_manager_config.OAUTH_AUDIENCE: - kwargs["audience"] = auth_manager_config.OAUTH_AUDIENCE + kwargs['audience'] = auth_manager_config.OAUTH_AUDIENCE return await client.authorize_redirect(request, redirect_uri, **kwargs) @@ -1443,18 +1300,15 @@ class OAuthManager: auth_params = {} if client: - if ( - hasattr(client, "client_id") - and OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID - ): - auth_params["client_id"] = client.client_id + if hasattr(client, 'client_id') and OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID: + auth_params['client_id'] = client.client_id try: token = await client.authorize_access_token(request, **auth_params) except Exception as e: detailed_error = _build_oauth_callback_error_message(e) log.warning( - "OAuth callback error during authorize_access_token for provider %s: %s", + 'OAuth callback error during authorize_access_token for provider %s: %s', provider, detailed_error, exc_info=True, @@ -1462,21 +1316,17 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Try to get userinfo from the token first, some providers include it there - user_data: UserInfo = token.get("userinfo") + user_data: UserInfo = token.get('userinfo') if ( (not user_data) or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data) or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data) ): user_data: UserInfo = await client.userinfo(token=token) - if ( - provider == "feishu" - and isinstance(user_data, dict) - and "data" in user_data - ): - user_data = user_data["data"] + if provider == 'feishu' and isinstance(user_data, dict) and 'data' in user_data: + user_data = user_data['data'] if not user_data: - log.warning(f"OAuth callback failed, user data is missing: {token}") + log.warning(f'OAuth callback failed, user data is missing: {token}') raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Extract the "sub" claim, using custom claim if configured @@ -1484,29 +1334,29 @@ class OAuthManager: sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM) else: # Fallback to the default sub claim if not configured - sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub")) + sub = user_data.get(OAUTH_PROVIDERS[provider].get('sub_claim', 'sub')) if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") + log.warning(f'OAuth callback failed, sub is missing: {user_data}') raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) oauth_data = {} oauth_data[provider] = { - "sub": sub, + 'sub': sub, } # Email extraction email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "") + email = user_data.get(email_claim, '') # We currently mandate that email addresses are provided if not email: # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email - if provider == "github": + if provider == 'github': try: - access_token = token.get("access_token") - headers = {"Authorization": f"Bearer {access_token}"} + access_token = token.get('access_token') + headers = {'Authorization': f'Bearer {access_token}'} async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - "https://api.github.com/user/emails", + 'https://api.github.com/user/emails', headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as resp: @@ -1514,46 +1364,33 @@ class OAuthManager: emails = await resp.json() # use the primary email as the user's email primary_email = next( - ( - e["email"] - for e in emails - if e.get("primary") - ), + (e['email'] for e in emails if e.get('primary')), None, ) if primary_email: email = primary_email else: - log.warning( - "No primary email found in GitHub response" - ) - raise HTTPException( - 400, detail=ERROR_MESSAGES.INVALID_CRED - ) + log.warning('No primary email found in GitHub response') + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - log.warning("Failed to fetch GitHub email") - raise HTTPException( - 400, detail=ERROR_MESSAGES.INVALID_CRED - ) + log.warning('Failed to fetch GitHub email') + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) except Exception as e: - log.warning(f"Error fetching GitHub email: {e}") + log.warning(f'Error fetching GitHub email: {e}') raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) elif ENABLE_OAUTH_EMAIL_FALLBACK: - email = f"{provider}@{sub}.local" + email = f'{provider}@{sub}.local' else: - log.warning(f"OAuth callback failed, email is missing: {user_data}") + log.warning(f'OAuth callback failed, email is missing: {user_data}') raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) email = email.lower() # If allowed domains are configured, check if the email domain is in the list if ( - "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS - and email.split("@")[-1] - not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + '*' not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split('@')[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS ): - log.warning( - f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" - ) + log.warning(f'OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}') raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists @@ -1580,9 +1417,9 @@ class OAuthManager: if username_claim: new_name = user_data.get(username_claim) if new_name and new_name != user.name: - Users.update_user_by_id(user.id, {"name": new_name}, db=db) + Users.update_user_by_id(user.id, {'name': new_name}, db=db) user.name = new_name - log.debug(f"Updated name for user {user.email}") + log.debug(f'Updated name for user {user.email}') if auth_manager_config.OAUTH_UPDATE_EMAIL_ON_LOGIN: email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM @@ -1592,14 +1429,12 @@ class OAuthManager: existing_user = Users.get_user_by_email(new_email, db=db) if existing_user: log.error( - f"Cannot update email to {new_email} for user {user.id} because it is already taken." + f'Cannot update email to {new_email} for user {user.id} because it is already taken.' ) else: - Auths.update_email_by_id( - user.id, new_email.lower(), db=db - ) + Auths.update_email_by_id(user.id, new_email.lower(), db=db) user.email = new_email.lower() - log.debug(f"Updated email for user {user.id}") + log.debug(f'Updated email for user {user.id}') # Update profile picture if enabled and different from current if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN: @@ -1607,16 +1442,14 @@ class OAuthManager: if picture_claim: new_picture_url = user_data.get( picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), + OAUTH_PROVIDERS[provider].get('picture_url', ''), ) processed_picture_url = await self._process_picture_url( - new_picture_url, token.get("access_token") + new_picture_url, token.get('access_token') ) if processed_picture_url != user.profile_image_url: - Users.update_user_profile_image_url_by_id( - user.id, processed_picture_url, db=db - ) - log.debug(f"Updated profile picture for user {user.email}") + Users.update_user_profile_image_url_by_id(user.id, processed_picture_url, db=db) + log.debug(f'Updated profile picture for user {user.email}') else: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: @@ -1629,25 +1462,21 @@ class OAuthManager: if picture_claim: picture_url = user_data.get( picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), - ) - picture_url = await self._process_picture_url( - picture_url, token.get("access_token") + OAUTH_PROVIDERS[provider].get('picture_url', ''), ) + picture_url = await self._process_picture_url(picture_url, token.get('access_token')) else: - picture_url = "/user.png" + picture_url = '/user.png' username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM name = user_data.get(username_claim) if not name: - log.warning("Username claim is missing, using email as name") + log.warning('Username claim is missing, using email as name') name = email user = Auths.insert_new_auth( email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used + password=get_password_hash(str(uuid.uuid4())), # Random password, not used name=name, profile_image_url=picture_url, role=self.get_user_role(None, user_data), @@ -1661,15 +1490,13 @@ class OAuthManager: auth_manager_config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), + 'action': 'signup', + 'message': WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + 'user': user.model_dump_json(exclude_none=True), }, ) - apply_default_group_assignment( - request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db - ) + apply_default_group_assignment(request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db) else: raise HTTPException( @@ -1678,13 +1505,10 @@ class OAuthManager: ) jwt_token = create_token( - data={"id": user.id}, + data={'id': user.id}, expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), ) - if ( - auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT - and user.role != "admin" - ): + if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != 'admin': self.update_user_groups( user=user, user_data=user_data, @@ -1693,22 +1517,18 @@ class OAuthManager: ) except Exception as e: - log.error(f"Error during OAuth process: {e}") + log.error(f'Error during OAuth process: {e}') error_message = ( e.detail if isinstance(e, HTTPException) and e.detail - else ERROR_MESSAGES.DEFAULT("Error during OAuth process") + else ERROR_MESSAGES.DEFAULT('Error during OAuth process') ) - redirect_base_url = ( - str(request.app.state.config.WEBUI_URL or request.base_url) - ).rstrip("/") - redirect_url = f"{redirect_base_url}/auth" + redirect_base_url = (str(request.app.state.config.WEBUI_URL or request.base_url)).rstrip('/') + redirect_url = f'{redirect_base_url}/auth' if error_message: - redirect_url = ( - f"{redirect_url}?error={urllib.parse.quote_plus(error_message)}" - ) + redirect_url = f'{redirect_url}?error={urllib.parse.quote_plus(error_message)}' return RedirectResponse(url=redirect_url, headers=response.headers) response = RedirectResponse(url=redirect_url, headers=response.headers) @@ -1716,7 +1536,7 @@ class OAuthManager: # Set the cookie token # Redirect back to the frontend with the JWT token response.set_cookie( - key="token", + key='token', value=jwt_token, httponly=False, # Required for frontend access samesite=WEBUI_AUTH_COOKIE_SAME_SITE, @@ -1726,8 +1546,8 @@ class OAuthManager: # Legacy cookies for compatibility with older frontend versions if ENABLE_OAUTH_ID_TOKEN_COOKIE: response.set_cookie( - key="oauth_id_token", - value=token.get("id_token"), + key='oauth_id_token', + value=token.get('id_token'), httponly=True, samesite=WEBUI_AUTH_COOKIE_SAME_SITE, secure=WEBUI_AUTH_COOKIE_SECURE, @@ -1735,11 +1555,11 @@ class OAuthManager: try: # Add timestamp for tracking - token["issued_at"] = datetime.now().timestamp() + token['issued_at'] = datetime.now().timestamp() # Calculate expires_at if we have expires_in - if "expires_in" in token and "expires_at" not in token: - token["expires_at"] = datetime.now().timestamp() + token["expires_in"] + if 'expires_in' in token and 'expires_at' not in token: + token['expires_at'] = datetime.now().timestamp() + token['expires_in'] # Enforce max concurrent sessions per user/provider to prevent # unbounded growth while allowing multi-device usage @@ -1763,21 +1583,17 @@ class OAuthManager: if session: response.set_cookie( - key="oauth_session_id", + key='oauth_session_id', value=session.id, httponly=True, samesite=WEBUI_AUTH_COOKIE_SAME_SITE, secure=WEBUI_AUTH_COOKIE_SECURE, ) - log.info( - f"Stored OAuth session server-side for user {user.id}, provider {provider}" - ) + log.info(f'Stored OAuth session server-side for user {user.id}, provider {provider}') else: - log.warning( - f"Failed to create OAuth session for user {user.id}, provider {provider}" - ) + log.warning(f'Failed to create OAuth session for user {user.id}, provider {provider}') except Exception as e: - log.error(f"Failed to store OAuth session server-side: {e}") + log.error(f'Failed to store OAuth session server-side: {e}') return response diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 168ec893b2..21828d93f1 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -23,7 +23,7 @@ def apply_system_prompt_to_body( # Metadata (WebUI Usage) if metadata: - variables = metadata.get("variables", {}) + variables = metadata.get('variables', {}) if variables: system = prompt_variables_template(system, variables) @@ -31,21 +31,15 @@ def apply_system_prompt_to_body( system = prompt_template(system, user) if replace: - form_data["messages"] = replace_system_message_content( - system, form_data.get("messages", []) - ) + form_data['messages'] = replace_system_message_content(system, form_data.get('messages', [])) else: - form_data["messages"] = add_or_update_system_message( - system, form_data.get("messages", []) - ) + form_data['messages'] = add_or_update_system_message(system, form_data.get('messages', [])) return form_data # inplace function: form_data is modified -def apply_model_params_to_body( - params: dict, form_data: dict, mappings: dict[str, Callable] -) -> dict: +def apply_model_params_to_body(params: dict, form_data: dict, mappings: dict[str, Callable]) -> dict: if not params: return form_data @@ -72,11 +66,11 @@ def remove_open_webui_params(params: dict) -> dict: dict: The modified dictionary with OpenWebUI parameters removed. """ open_webui_params = { - "stream_response": bool, - "stream_delta_chunk_size": int, - "function_calling": str, - "reasoning_tags": list, - "system": str, + 'stream_response': bool, + 'stream_delta_chunk_size': int, + 'function_calling': str, + 'reasoning_tags': list, + 'system': str, } for key in list(params.keys()): @@ -90,7 +84,7 @@ def remove_open_webui_params(params: dict) -> dict: def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: params = remove_open_webui_params(params) - custom_params = params.pop("custom_params", {}) + custom_params = params.pop('custom_params', {}) if custom_params: # Attempt to parse custom_params if they are strings for key, value in custom_params.items(): @@ -106,17 +100,17 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: params = deep_update(params, custom_params) mappings = { - "temperature": float, - "top_p": float, - "min_p": float, - "max_tokens": int, - "frequency_penalty": float, - "presence_penalty": float, - "reasoning_effort": str, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - "logit_bias": lambda x: x, - "response_format": dict, + 'temperature': float, + 'top_p': float, + 'min_p': float, + 'max_tokens': int, + 'frequency_penalty': float, + 'presence_penalty': float, + 'reasoning_effort': str, + 'seed': lambda x: x, + 'stop': lambda x: [bytes(s, 'utf-8').decode('unicode_escape') for s in x], + 'logit_bias': lambda x: x, + 'response_format': dict, } return apply_model_params_to_body(params, form_data, mappings) @@ -124,7 +118,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: params = remove_open_webui_params(params) - custom_params = params.pop("custom_params", {}) + custom_params = params.pop('custom_params', {}) if custom_params: # Attempt to parse custom_params if they are strings for key, value in custom_params.items(): @@ -141,7 +135,7 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: # Convert OpenAI parameter names to Ollama parameter names if needed. name_differences = { - "max_tokens": "num_predict", + 'max_tokens': 'num_predict', } for key, value in name_differences.items(): @@ -152,27 +146,27 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: # See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8 mappings = { - "temperature": float, - "top_p": float, - "seed": lambda x: x, - "mirostat": int, - "mirostat_eta": float, - "mirostat_tau": float, - "num_ctx": int, - "num_batch": int, - "num_keep": int, - "num_predict": int, - "repeat_last_n": int, - "top_k": int, - "min_p": float, - "repeat_penalty": float, - "presence_penalty": float, - "frequency_penalty": float, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - "num_gpu": int, - "use_mmap": bool, - "use_mlock": bool, - "num_thread": int, + 'temperature': float, + 'top_p': float, + 'seed': lambda x: x, + 'mirostat': int, + 'mirostat_eta': float, + 'mirostat_tau': float, + 'num_ctx': int, + 'num_batch': int, + 'num_keep': int, + 'num_predict': int, + 'repeat_last_n': int, + 'top_k': int, + 'min_p': float, + 'repeat_penalty': float, + 'presence_penalty': float, + 'frequency_penalty': float, + 'stop': lambda x: [bytes(s, 'utf-8').decode('unicode_escape') for s in x], + 'num_gpu': int, + 'use_mmap': bool, + 'use_mlock': bool, + 'num_thread': int, } def parse_json(value: str) -> dict: @@ -185,9 +179,9 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: return value ollama_root_params = { - "format": lambda x: parse_json(x), - "keep_alive": lambda x: parse_json(x), - "think": lambda x: x, + 'format': lambda x: parse_json(x), + 'keep_alive': lambda x: parse_json(x), + 'think': lambda x: x, } for key, value in ollama_root_params.items(): @@ -197,9 +191,7 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: del params[key] # Unlike OpenAI, Ollama does not support params directly in the body - form_data["options"] = apply_model_params_to_body( - params, (form_data.get("options", {}) or {}), mappings - ) + form_data['options'] = apply_model_params_to_body(params, (form_data.get('options', {}) or {}), mappings) return form_data @@ -208,68 +200,66 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]: for message in messages: # Initialize the new message structure with the role - new_message = {"role": message["role"]} + new_message = {'role': message['role']} - content = message.get("content", []) - tool_calls = message.get("tool_calls", None) - tool_call_id = message.get("tool_call_id", None) + content = message.get('content', []) + tool_calls = message.get('tool_calls', None) + tool_call_id = message.get('tool_call_id', None) # Check if the content is a string (just a simple message) if isinstance(content, str) and not tool_calls: # If the content is a string, it's pure text - new_message["content"] = content + new_message['content'] = content # If message is a tool call, add the tool call id to the message if tool_call_id: - new_message["tool_call_id"] = tool_call_id + new_message['tool_call_id'] = tool_call_id elif tool_calls: # If tool calls are present, add them to the message ollama_tool_calls = [] for tool_call in tool_calls: ollama_tool_call = { - "index": tool_call.get("index", 0), - "id": tool_call.get("id", None), - "function": { - "name": tool_call.get("function", {}).get("name", ""), - "arguments": json.loads( - tool_call.get("function", {}).get("arguments", {}) - ), + 'index': tool_call.get('index', 0), + 'id': tool_call.get('id', None), + 'function': { + 'name': tool_call.get('function', {}).get('name', ''), + 'arguments': json.loads(tool_call.get('function', {}).get('arguments', {})), }, } ollama_tool_calls.append(ollama_tool_call) - new_message["tool_calls"] = ollama_tool_calls + new_message['tool_calls'] = ollama_tool_calls # Put the content to empty string (Ollama requires an empty string for tool calls) - new_message["content"] = "" + new_message['content'] = '' else: # Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL - content_text = "" + content_text = '' images = [] # Iterate through the list of content items for item in content: # Check if it's a text type - if item.get("type") == "text": - content_text += item.get("text", "") + if item.get('type') == 'text': + content_text += item.get('text', '') # Check if it's an image URL type - elif item.get("type") == "image_url": - img_url = item.get("image_url", {}).get("url", "") + elif item.get('type') == 'image_url': + img_url = item.get('image_url', {}).get('url', '') if img_url: # If the image url starts with data:, it's a base64 image and should be trimmed - if img_url.startswith("data:"): - img_url = img_url.split(",")[-1] + if img_url.startswith('data:'): + img_url = img_url.split(',')[-1] images.append(img_url) # Add content text (if any) if content_text: - new_message["content"] = content_text.strip() + new_message['content'] = content_text.strip() # Add images (if any) if images: - new_message["images"] = images + new_message['images'] = images # Append the new formatted message to the result ollama_messages.append(new_message) @@ -288,31 +278,27 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: dict: A modified payload compatible with the Ollama API. """ # Shallow copy metadata separately (may contain non-picklable objects) - metadata = openai_payload.get("metadata") - openai_payload = copy.deepcopy( - {k: v for k, v in openai_payload.items() if k != "metadata"} - ) + metadata = openai_payload.get('metadata') + openai_payload = copy.deepcopy({k: v for k, v in openai_payload.items() if k != 'metadata'}) if metadata is not None: - openai_payload["metadata"] = dict(metadata) + openai_payload['metadata'] = dict(metadata) ollama_payload = {} # Mapping basic model and message details - ollama_payload["model"] = openai_payload.get("model") - ollama_payload["messages"] = convert_messages_openai_to_ollama( - openai_payload.get("messages") - ) - ollama_payload["stream"] = openai_payload.get("stream", False) - if "tools" in openai_payload: - ollama_payload["tools"] = openai_payload["tools"] + ollama_payload['model'] = openai_payload.get('model') + ollama_payload['messages'] = convert_messages_openai_to_ollama(openai_payload.get('messages')) + ollama_payload['stream'] = openai_payload.get('stream', False) + if 'tools' in openai_payload: + ollama_payload['tools'] = openai_payload['tools'] - if "max_tokens" in openai_payload: - ollama_payload["num_predict"] = openai_payload["max_tokens"] - del openai_payload["max_tokens"] + if 'max_tokens' in openai_payload: + ollama_payload['num_predict'] = openai_payload['max_tokens'] + del openai_payload['max_tokens'] # If there are advanced parameters in the payload, format them in Ollama's options field - if openai_payload.get("options"): - ollama_payload["options"] = openai_payload["options"] - ollama_options = openai_payload["options"] + if openai_payload.get('options'): + ollama_payload['options'] = openai_payload['options'] + ollama_options = openai_payload['options'] def parse_json(value: str) -> dict: """ @@ -324,9 +310,9 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: return value ollama_root_params = { - "format": lambda x: parse_json(x), - "keep_alive": lambda x: parse_json(x), - "think": lambda x: x, + 'format': lambda x: parse_json(x), + 'keep_alive': lambda x: parse_json(x), + 'think': lambda x: x, } # Ollama's options field can contain parameters that should be at the root level. @@ -337,35 +323,35 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: del ollama_options[key] # Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict` - if "max_tokens" in ollama_options: - ollama_options["num_predict"] = ollama_options["max_tokens"] - del ollama_options["max_tokens"] + if 'max_tokens' in ollama_options: + ollama_options['num_predict'] = ollama_options['max_tokens'] + del ollama_options['max_tokens'] # Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down. # Comment: Not sure why this is needed, but we'll keep it for compatibility. - if "system" in ollama_options: - ollama_payload["system"] = ollama_options["system"] - del ollama_options["system"] + if 'system' in ollama_options: + ollama_payload['system'] = ollama_options['system'] + del ollama_options['system'] - ollama_payload["options"] = ollama_options + ollama_payload['options'] = ollama_options # If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options - if "stop" in openai_payload: - ollama_options = ollama_payload.get("options", {}) - ollama_options["stop"] = openai_payload.get("stop") - ollama_payload["options"] = ollama_options + if 'stop' in openai_payload: + ollama_options = ollama_payload.get('options', {}) + ollama_options['stop'] = openai_payload.get('stop') + ollama_payload['options'] = ollama_options - if "metadata" in openai_payload: - ollama_payload["metadata"] = openai_payload["metadata"] + if 'metadata' in openai_payload: + ollama_payload['metadata'] = openai_payload['metadata'] - if "response_format" in openai_payload: - response_format = openai_payload["response_format"] - format_type = response_format.get("type", None) + if 'response_format' in openai_payload: + response_format = openai_payload['response_format'] + format_type = response_format.get('type', None) schema = response_format.get(format_type, None) if schema: - format = schema.get("schema", None) - ollama_payload["format"] = format + format = schema.get('schema', None) + ollama_payload['format'] = format return ollama_payload @@ -380,19 +366,19 @@ def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict: Returns: dict: A payload compatible with the Ollama API embeddings endpoint. """ - ollama_payload = {"model": openai_payload.get("model")} - input_value = openai_payload.get("input") + ollama_payload = {'model': openai_payload.get('model')} + input_value = openai_payload.get('input') # Ollama expects 'input' as a list, and 'prompt' as a single string. if isinstance(input_value, list): - ollama_payload["input"] = input_value - ollama_payload["prompt"] = "\n".join(str(x) for x in input_value) + ollama_payload['input'] = input_value + ollama_payload['prompt'] = '\n'.join(str(x) for x in input_value) else: - ollama_payload["input"] = [input_value] - ollama_payload["prompt"] = str(input_value) + ollama_payload['input'] = [input_value] + ollama_payload['prompt'] = str(input_value) # Optionally forward other fields if present - for optional_key in ("options", "truncate", "keep_alive"): + for optional_key in ('options', 'truncate', 'keep_alive'): if optional_key in openai_payload: ollama_payload[optional_key] = openai_payload[optional_key] @@ -411,14 +397,14 @@ def convert_embed_payload_openai_to_ollama(openai_payload: dict) -> dict: Returns: dict: A payload compatible with the Ollama /api/embed endpoint. """ - ollama_payload = {"model": openai_payload.get("model")} - input_value = openai_payload.get("input") + ollama_payload = {'model': openai_payload.get('model')} + input_value = openai_payload.get('input') # /api/embed accepts 'input' as a string or list of strings directly - ollama_payload["input"] = input_value + ollama_payload['input'] = input_value # Optionally forward other fields if present - for optional_key in ("truncate", "options", "keep_alive"): + for optional_key in ('truncate', 'options', 'keep_alive'): if optional_key in openai_payload: ollama_payload[optional_key] = openai_payload[optional_key] diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index c137b49da0..3db4297a21 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -29,32 +29,32 @@ class PDFGenerator: self.messages_html = None self.form_data = form_data - self.css = Path(STATIC_DIR / "assets" / "pdf-style.css").read_text() + self.css = Path(STATIC_DIR / 'assets' / 'pdf-style.css').read_text() def format_timestamp(self, timestamp: float) -> str: """Convert a UNIX timestamp to a formatted date string.""" try: date_time = datetime.fromtimestamp(timestamp) - return date_time.strftime("%Y-%m-%d, %H:%M:%S") + return date_time.strftime('%Y-%m-%d, %H:%M:%S') except (ValueError, TypeError) as e: # Log the error if necessary - return "" + return '' def _build_html_message(self, message: Dict[str, Any]) -> str: """Build HTML for a single message.""" - role = escape(message.get("role", "user")) - content = escape(message.get("content", "")) - timestamp = message.get("timestamp") + role = escape(message.get('role', 'user')) + content = escape(message.get('content', '')) + timestamp = message.get('timestamp') - model = escape(message.get("model") if role == "assistant" else "") + model = escape(message.get('model') if role == 'assistant' else '') - date_str = escape(self.format_timestamp(timestamp) if timestamp else "") + date_str = escape(self.format_timestamp(timestamp) if timestamp else '') # extends pymdownx extension to convert markdown to html. # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ # html_content = markdown(content, extensions=["pymdownx.extra"]) - content = content.replace("\n", "
") + content = content.replace('\n', '
') html_message = f"""
@@ -106,32 +106,28 @@ class PDFGenerator: # When running using `pip install` the static directory is in the site packages. if not FONTS_DIR.exists(): - FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts" + FONTS_DIR = Path(site.getsitepackages()[0]) / 'static/fonts' # When running using `pip install -e .` the static directory is in the site packages. # This path only works if `open-webui serve` is run from the root of this project. if not FONTS_DIR.exists(): - FONTS_DIR = Path(".") / "backend" / "static" / "fonts" + FONTS_DIR = Path('.') / 'backend' / 'static' / 'fonts' - pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf") - pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf") - pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf") - pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") - pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") - pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") - pdf.add_font("Twemoji", "", f"{FONTS_DIR}/Twemoji.ttf") + pdf.add_font('NotoSans', '', f'{FONTS_DIR}/NotoSans-Regular.ttf') + pdf.add_font('NotoSans', 'b', f'{FONTS_DIR}/NotoSans-Bold.ttf') + pdf.add_font('NotoSans', 'i', f'{FONTS_DIR}/NotoSans-Italic.ttf') + pdf.add_font('NotoSansKR', '', f'{FONTS_DIR}/NotoSansKR-Regular.ttf') + pdf.add_font('NotoSansJP', '', f'{FONTS_DIR}/NotoSansJP-Regular.ttf') + pdf.add_font('NotoSansSC', '', f'{FONTS_DIR}/NotoSansSC-Regular.ttf') + pdf.add_font('Twemoji', '', f'{FONTS_DIR}/Twemoji.ttf') - pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts( - ["NotoSansKR", "NotoSansJP", "NotoSansSC", "Twemoji"] - ) + pdf.set_font('NotoSans', size=12) + pdf.set_fallback_fonts(['NotoSansKR', 'NotoSansJP', 'NotoSansSC', 'Twemoji']) pdf.set_auto_page_break(auto=True, margin=15) # Build HTML messages - messages_html_list: List[str] = [ - self._build_html_message(msg) for msg in self.form_data.messages - ] - self.messages_html = "
" + "".join(messages_html_list) + "
" + messages_html_list: List[str] = [self._build_html_message(msg) for msg in self.form_data.messages] + self.messages_html = '
' + ''.join(messages_html_list) + '
' # Generate full HTML body self.html_body = self._generate_html_body() diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index a2d7b9ad11..6dae37e531 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -20,9 +20,7 @@ from open_webui.models.tools import Tools log = logging.getLogger(__name__) -def resolve_valves_schema_options( - valves_class: type, schema: dict, user: Any = None -) -> dict: +def resolve_valves_schema_options(valves_class: type, schema: dict, user: Any = None) -> dict: """ Resolve dynamic options in a Valves schema. @@ -66,16 +64,16 @@ def resolve_valves_schema_options( Returns: Modified schema dict with resolved options """ - if not schema or "properties" not in schema: + if not schema or 'properties' not in schema: return schema # Make a copy to avoid mutating the original schema = dict(schema) - schema["properties"] = dict(schema.get("properties", {})) + schema['properties'] = dict(schema.get('properties', {})) - for prop_name, prop_schema in list(schema["properties"].items()): + for prop_name, prop_schema in list(schema['properties'].items()): # Get the original field info from the Pydantic model - if not hasattr(valves_class, "model_fields"): + if not hasattr(valves_class, 'model_fields'): continue field_info = valves_class.model_fields.get(prop_name) @@ -87,11 +85,11 @@ def resolve_valves_schema_options( if not json_schema_extra or not isinstance(json_schema_extra, dict): continue - input_config = json_schema_extra.get("input") + input_config = json_schema_extra.get('input') if not input_config or not isinstance(input_config, dict): continue - options = input_config.get("options") + options = input_config.get('options') if options is None: continue @@ -105,9 +103,7 @@ def resolve_valves_schema_options( elif isinstance(options, str) and options: method = getattr(valves_class, options, None) if method is None or not callable(method): - log.warning( - f"options '{options}' not found or not callable on {valves_class.__name__}" - ) + log.warning(f"options '{options}' not found or not callable on {valves_class.__name__}") continue try: @@ -118,40 +114,32 @@ def resolve_valves_schema_options( # Prepare kwargs based on what the method accepts kwargs = {} - if "__user__" in params and user is not None: - kwargs["__user__"] = ( - user.model_dump() if hasattr(user, "model_dump") else user - ) - if "user" in params and user is not None: - kwargs["user"] = ( - user.model_dump() if hasattr(user, "model_dump") else user - ) + if '__user__' in params and user is not None: + kwargs['__user__'] = user.model_dump() if hasattr(user, 'model_dump') else user + if 'user' in params and user is not None: + kwargs['user'] = user.model_dump() if hasattr(user, 'model_dump') else user resolved_options = method(**kwargs) if kwargs else method() # Validate return type if not isinstance(resolved_options, list): - log.warning( - f"Method '{options}' did not return a list for {prop_name}" - ) + log.warning(f"Method '{options}' did not return a list for {prop_name}") continue except Exception as e: - log.warning(f"Failed to resolve options for {prop_name}: {e}") + log.warning(f'Failed to resolve options for {prop_name}: {e}') continue else: # Invalid options type - skip continue # Update the schema with resolved options - schema["properties"][prop_name] = dict(prop_schema) - if "input" not in schema["properties"][prop_name]: - schema["properties"][prop_name]["input"] = {"type": "select"} + schema['properties'][prop_name] = dict(prop_schema) + if 'input' not in schema['properties'][prop_name]: + schema['properties'][prop_name]['input'] = {'type': 'select'} else: - schema["properties"][prop_name]["input"] = dict( - schema["properties"][prop_name].get("input", {}) - ) - schema["properties"][prop_name]["input"]["options"] = resolved_options + schema['properties'][prop_name]['input'] = dict(schema['properties'][prop_name].get('input', {})) + schema['properties'][prop_name]['input']['options'] = resolved_options return schema @@ -163,7 +151,7 @@ def extract_frontmatter(content): frontmatter = {} frontmatter_started = False frontmatter_ended = False - frontmatter_pattern = re.compile(r"^\s*([a-z_]+):\s*(.*)\s*$", re.IGNORECASE) + frontmatter_pattern = re.compile(r'^\s*([a-z_]+):\s*(.*)\s*$', re.IGNORECASE) try: lines = content.splitlines() @@ -186,7 +174,7 @@ def extract_frontmatter(content): frontmatter[key.strip()] = value.strip() except Exception as e: - log.exception(f"Failed to extract frontmatter: {e}") + log.exception(f'Failed to extract frontmatter: {e}') return {} return frontmatter @@ -197,10 +185,10 @@ def replace_imports(content): Replace the import paths in the content. """ replacements = { - "from utils": "from open_webui.utils", - "from apps": "from open_webui.apps", - "from main": "from open_webui.main", - "from config": "from open_webui.config", + 'from utils': 'from open_webui.utils', + 'from apps': 'from open_webui.apps', + 'from main': 'from open_webui.main', + 'from config': 'from open_webui.config', } for old, new in replacements.items(): @@ -210,22 +198,21 @@ def replace_imports(content): def load_tool_module_by_id(tool_id, content=None): - if content is None: tool = Tools.get_tool_by_id(tool_id) if not tool: - raise Exception(f"Toolkit not found: {tool_id}") + raise Exception(f'Toolkit not found: {tool_id}') content = tool.content content = replace_imports(content) - Tools.update_tool_by_id(tool_id, {"content": content}) + Tools.update_tool_by_id(tool_id, {'content': content}) else: frontmatter = extract_frontmatter(content) # Install required packages found within the frontmatter - install_frontmatter_requirements(frontmatter.get("requirements", "")) + install_frontmatter_requirements(frontmatter.get('requirements', '')) - module_name = f"tool_{tool_id}" + module_name = f'tool_{tool_id}' module = types.ModuleType(module_name) sys.modules[module_name] = module @@ -234,22 +221,22 @@ def load_tool_module_by_id(tool_id, content=None): temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file.close() try: - with open(temp_file.name, "w", encoding="utf-8") as f: + with open(temp_file.name, 'w', encoding='utf-8') as f: f.write(content) - module.__dict__["__file__"] = temp_file.name + module.__dict__['__file__'] = temp_file.name # Executing the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - log.info(f"Loaded module: {module.__name__}") + log.info(f'Loaded module: {module.__name__}') # Create and return the object if the class 'Tools' is found in the module - if hasattr(module, "Tools"): + if hasattr(module, 'Tools'): return module.Tools(), frontmatter else: - raise Exception("No Tools class found in the module") + raise Exception('No Tools class found in the module') except Exception as e: - log.error(f"Error loading module: {tool_id}: {e}") + log.error(f'Error loading module: {tool_id}: {e}') del sys.modules[module_name] # Clean up raise e finally: @@ -260,16 +247,16 @@ def load_function_module_by_id(function_id: str, content: str | None = None): if content is None: function = Functions.get_function_by_id(function_id) if not function: - raise Exception(f"Function not found: {function_id}") + raise Exception(f'Function not found: {function_id}') content = function.content content = replace_imports(content) - Functions.update_function_by_id(function_id, {"content": content}) + Functions.update_function_by_id(function_id, {'content': content}) else: frontmatter = extract_frontmatter(content) - install_frontmatter_requirements(frontmatter.get("requirements", "")) + install_frontmatter_requirements(frontmatter.get('requirements', '')) - module_name = f"function_{function_id}" + module_name = f'function_{function_id}' module = types.ModuleType(module_name) sys.modules[module_name] = module @@ -278,30 +265,30 @@ def load_function_module_by_id(function_id: str, content: str | None = None): temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file.close() try: - with open(temp_file.name, "w", encoding="utf-8") as f: + with open(temp_file.name, 'w', encoding='utf-8') as f: f.write(content) - module.__dict__["__file__"] = temp_file.name + module.__dict__['__file__'] = temp_file.name # Execute the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - log.info(f"Loaded module: {module.__name__}") + log.info(f'Loaded module: {module.__name__}') # Create appropriate object based on available class type in the module - if hasattr(module, "Pipe"): - return module.Pipe(), "pipe", frontmatter - elif hasattr(module, "Filter"): - return module.Filter(), "filter", frontmatter - elif hasattr(module, "Action"): - return module.Action(), "action", frontmatter + if hasattr(module, 'Pipe'): + return module.Pipe(), 'pipe', frontmatter + elif hasattr(module, 'Filter'): + return module.Filter(), 'filter', frontmatter + elif hasattr(module, 'Action'): + return module.Action(), 'action', frontmatter else: - raise Exception("No Function class found in the module") + raise Exception('No Function class found in the module') except Exception as e: - log.error(f"Error loading module: {function_id}: {e}") + log.error(f'Error loading module: {function_id}: {e}') # Cleanup by removing the module in case of error del sys.modules[module_name] - Functions.update_function_by_id(function_id, {"is_active": False}) + Functions.update_function_by_id(function_id, {'is_active': False}) raise e finally: os.unlink(temp_file.name) @@ -312,35 +299,32 @@ def get_tool_module_from_cache(request, tool_id, load_from_db=True): # Always load from the database by default tool = Tools.get_tool_by_id(tool_id) if not tool: - raise Exception(f"Tool not found: {tool_id}") + raise Exception(f'Tool not found: {tool_id}') content = tool.content new_content = replace_imports(content) if new_content != content: content = new_content # Update the tool content in the database - Tools.update_tool_by_id(tool_id, {"content": content}) + Tools.update_tool_by_id(tool_id, {'content': content}) - if ( - hasattr(request.app.state, "TOOL_CONTENTS") - and tool_id in request.app.state.TOOL_CONTENTS - ) and ( - hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS + if (hasattr(request.app.state, 'TOOL_CONTENTS') and tool_id in request.app.state.TOOL_CONTENTS) and ( + hasattr(request.app.state, 'TOOLS') and tool_id in request.app.state.TOOLS ): if request.app.state.TOOL_CONTENTS[tool_id] == content: return request.app.state.TOOLS[tool_id], None tool_module, frontmatter = load_tool_module_by_id(tool_id, content) else: - if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS: + if hasattr(request.app.state, 'TOOLS') and tool_id in request.app.state.TOOLS: return request.app.state.TOOLS[tool_id], None tool_module, frontmatter = load_tool_module_by_id(tool_id) - if not hasattr(request.app.state, "TOOLS"): + if not hasattr(request.app.state, 'TOOLS'): request.app.state.TOOLS = {} - if not hasattr(request.app.state, "TOOL_CONTENTS"): + if not hasattr(request.app.state, 'TOOL_CONTENTS'): request.app.state.TOOL_CONTENTS = {} request.app.state.TOOLS[tool_id] = tool_module @@ -357,46 +341,35 @@ def get_function_module_from_cache(request, function_id, load_from_db=True): function = Functions.get_function_by_id(function_id) if not function: - raise Exception(f"Function not found: {function_id}") + raise Exception(f'Function not found: {function_id}') content = function.content new_content = replace_imports(content) if new_content != content: content = new_content # Update the function content in the database - Functions.update_function_by_id(function_id, {"content": content}) + Functions.update_function_by_id(function_id, {'content': content}) if ( - hasattr(request.app.state, "FUNCTION_CONTENTS") - and function_id in request.app.state.FUNCTION_CONTENTS - ) and ( - hasattr(request.app.state, "FUNCTIONS") - and function_id in request.app.state.FUNCTIONS - ): + hasattr(request.app.state, 'FUNCTION_CONTENTS') and function_id in request.app.state.FUNCTION_CONTENTS + ) and (hasattr(request.app.state, 'FUNCTIONS') and function_id in request.app.state.FUNCTIONS): if request.app.state.FUNCTION_CONTENTS[function_id] == content: return request.app.state.FUNCTIONS[function_id], None, None - function_module, function_type, frontmatter = load_function_module_by_id( - function_id, content - ) + function_module, function_type, frontmatter = load_function_module_by_id(function_id, content) else: # Load from cache (e.g. "stream" hook) # This is useful for performance reasons - if ( - hasattr(request.app.state, "FUNCTIONS") - and function_id in request.app.state.FUNCTIONS - ): + if hasattr(request.app.state, 'FUNCTIONS') and function_id in request.app.state.FUNCTIONS: return request.app.state.FUNCTIONS[function_id], None, None - function_module, function_type, frontmatter = load_function_module_by_id( - function_id - ) + function_module, function_type, frontmatter = load_function_module_by_id(function_id) - if not hasattr(request.app.state, "FUNCTIONS"): + if not hasattr(request.app.state, 'FUNCTIONS'): request.app.state.FUNCTIONS = {} - if not hasattr(request.app.state, "FUNCTION_CONTENTS"): + if not hasattr(request.app.state, 'FUNCTION_CONTENTS'): request.app.state.FUNCTION_CONTENTS = {} request.app.state.FUNCTIONS[function_id] = function_module @@ -407,31 +380,26 @@ def get_function_module_from_cache(request, function_id, load_from_db=True): def install_frontmatter_requirements(requirements: str): if not ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS: - log.info( - "ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS is disabled, skipping installation of requirements." - ) + log.info('ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS is disabled, skipping installation of requirements.') return if OFFLINE_MODE: - log.info("Offline mode enabled, skipping installation of requirements.") + log.info('Offline mode enabled, skipping installation of requirements.') return if requirements: try: - req_list = [req.strip() for req in requirements.split(",")] - log.info(f"Installing requirements: {' '.join(req_list)}") + req_list = [req.strip() for req in requirements.split(',')] + log.info(f'Installing requirements: {" ".join(req_list)}') subprocess.check_call( - [sys.executable, "-m", "pip", "install"] - + PIP_OPTIONS - + req_list - + PIP_PACKAGE_INDEX_OPTIONS + [sys.executable, '-m', 'pip', 'install'] + PIP_OPTIONS + req_list + PIP_PACKAGE_INDEX_OPTIONS ) except Exception as e: - log.error(f"Error installing packages: {' '.join(req_list)}") + log.error(f'Error installing packages: {" ".join(req_list)}') raise e else: - log.info("No requirements found in frontmatter.") + log.info('No requirements found in frontmatter.') def install_tool_and_function_dependencies(): @@ -445,19 +413,19 @@ def install_tool_and_function_dependencies(): function_list = Functions.get_functions(active_only=True) tool_list = Tools.get_tools() - all_dependencies = "" + all_dependencies = '' try: for function in function_list: frontmatter = extract_frontmatter(replace_imports(function.content)) - if dependencies := frontmatter.get("requirements"): - all_dependencies += f"{dependencies}, " + if dependencies := frontmatter.get('requirements'): + all_dependencies += f'{dependencies}, ' for tool in tool_list: # Only install requirements for admin tools - if tool.user and tool.user.role == "admin": + if tool.user and tool.user.role == 'admin': frontmatter = extract_frontmatter(replace_imports(tool.content)) - if dependencies := frontmatter.get("requirements"): - all_dependencies += f"{dependencies}, " + if dependencies := frontmatter.get('requirements'): + all_dependencies += f'{dependencies}, ' - install_frontmatter_requirements(all_dependencies.strip(", ")) + install_frontmatter_requirements(all_dependencies.strip(', ')) except Exception as e: - log.error(f"Error installing requirements: {e}") + log.error(f'Error installing requirements: {e}') diff --git a/backend/open_webui/utils/rate_limit.py b/backend/open_webui/utils/rate_limit.py index b657a937ab..93f3851d1f 100644 --- a/backend/open_webui/utils/rate_limit.py +++ b/backend/open_webui/utils/rate_limit.py @@ -35,7 +35,7 @@ class RateLimiter: self.enabled = enabled def _bucket_key(self, key: str, bucket_index: int) -> str: - return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}" + return f'{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}' def _current_bucket(self) -> int: return int(time.time()) // self.bucket_size @@ -84,9 +84,7 @@ class RateLimiter: self.r.expire(bucket_key, self.window + self.bucket_size) # Collect buckets - buckets = [ - self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1) - ] + buckets = [self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)] counts = self.r.mget(buckets) total = sum(int(c) for c in counts if c) @@ -95,9 +93,7 @@ class RateLimiter: def _get_count_redis(self, key: str) -> int: now_bucket = self._current_bucket() - buckets = [ - self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1) - ] + buckets = [self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)] counts = self.r.mget(buckets) return sum(int(c) for c in counts if c) diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index fcc4879ba3..a4d9d5cba5 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -40,7 +40,7 @@ class SentinelRedisProxy: if not callable(orig_attr): return orig_attr - FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"} + FACTORY_METHODS = {'pipeline', 'pubsub', 'monitor', 'client', 'transaction'} if item in FACTORY_METHODS: return orig_attr @@ -61,7 +61,7 @@ class SentinelRedisProxy: ) as e: if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1: log.debug( - "Redis sentinel fail-over (%s). Retry %s/%s", + 'Redis sentinel fail-over (%s). Retry %s/%s', type(e).__name__, i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, @@ -70,7 +70,7 @@ class SentinelRedisProxy: time.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( - "Redis operation failed after %s retries: %s", + 'Redis operation failed after %s retries: %s', REDIS_SENTINEL_MAX_RETRY_COUNT, e, ) @@ -94,7 +94,7 @@ class SentinelRedisProxy: ) as e: if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1: log.debug( - "Redis sentinel fail-over (%s). Retry %s/%s", + 'Redis sentinel fail-over (%s). Retry %s/%s', type(e).__name__, i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, @@ -103,7 +103,7 @@ class SentinelRedisProxy: await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( - "Redis operation failed after %s retries: %s", + 'Redis operation failed after %s retries: %s', REDIS_SENTINEL_MAX_RETRY_COUNT, e, ) @@ -124,7 +124,7 @@ class SentinelRedisProxy: ) as e: if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1: log.debug( - "Redis sentinel fail-over (%s). Retry %s/%s", + 'Redis sentinel fail-over (%s). Retry %s/%s', type(e).__name__, i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, @@ -133,7 +133,7 @@ class SentinelRedisProxy: time.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( - "Redis operation failed after %s retries: %s", + 'Redis operation failed after %s retries: %s', REDIS_SENTINEL_MAX_RETRY_COUNT, e, ) @@ -144,15 +144,15 @@ class SentinelRedisProxy: def parse_redis_service_url(redis_url): parsed_url = urlparse(redis_url) - if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss": + if parsed_url.scheme != 'redis' and parsed_url.scheme != 'rediss': raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.") return { - "username": parsed_url.username or None, - "password": parsed_url.password or None, - "service": parsed_url.hostname or "mymaster", - "port": parsed_url.port or 6379, - "db": int(parsed_url.path.lstrip("/") or 0), + 'username': parsed_url.username or None, + 'password': parsed_url.password or None, + 'service': parsed_url.hostname or 'mymaster', + 'port': parsed_url.port or 6379, + 'db': int(parsed_url.path.lstrip('/') or 0), } @@ -160,14 +160,12 @@ def get_redis_client(async_mode=False): try: return get_redis_connection( redis_url=REDIS_URL, - redis_sentinels=get_sentinels_from_env( - REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT - ), + redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), redis_cluster=REDIS_CLUSTER, async_mode=async_mode, ) except Exception as e: - log.debug(f"Failed to get Redis client: {e}") + log.debug(f'Failed to get Redis client: {e}') return None @@ -178,7 +176,6 @@ def get_redis_connection( async_mode=False, decode_responses=True, ): - cache_key = ( redis_url, tuple(redis_sentinels) if redis_sentinels else (), @@ -199,24 +196,22 @@ def get_redis_connection( redis_config = parse_redis_service_url(redis_url) sentinel = redis.sentinel.Sentinel( redis_sentinels, - port=redis_config["port"], - db=redis_config["db"], - username=redis_config["username"], - password=redis_config["password"], + port=redis_config['port'], + db=redis_config['db'], + username=redis_config['username'], + password=redis_config['password'], decode_responses=decode_responses, socket_connect_timeout=REDIS_SOCKET_CONNECT_TIMEOUT, ) connection = SentinelRedisProxy( sentinel, - redis_config["service"], + redis_config['service'], async_mode=async_mode, ) elif redis_cluster: if not redis_url: - raise ValueError("Redis URL must be provided for cluster mode.") - return redis.cluster.RedisCluster.from_url( - redis_url, decode_responses=decode_responses - ) + raise ValueError('Redis URL must be provided for cluster mode.') + return redis.cluster.RedisCluster.from_url(redis_url, decode_responses=decode_responses) elif redis_url: connection = redis.from_url(redis_url, decode_responses=decode_responses) else: @@ -226,28 +221,24 @@ def get_redis_connection( redis_config = parse_redis_service_url(redis_url) sentinel = redis.sentinel.Sentinel( redis_sentinels, - port=redis_config["port"], - db=redis_config["db"], - username=redis_config["username"], - password=redis_config["password"], + port=redis_config['port'], + db=redis_config['db'], + username=redis_config['username'], + password=redis_config['password'], decode_responses=decode_responses, socket_connect_timeout=REDIS_SOCKET_CONNECT_TIMEOUT, ) connection = SentinelRedisProxy( sentinel, - redis_config["service"], + redis_config['service'], async_mode=async_mode, ) elif redis_cluster: if not redis_url: - raise ValueError("Redis URL must be provided for cluster mode.") - return redis.cluster.RedisCluster.from_url( - redis_url, decode_responses=decode_responses - ) + raise ValueError('Redis URL must be provided for cluster mode.') + return redis.cluster.RedisCluster.from_url(redis_url, decode_responses=decode_responses) elif redis_url: - connection = redis.Redis.from_url( - redis_url, decode_responses=decode_responses - ) + connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses) _CONNECTION_CACHE[cache_key] = connection return connection @@ -255,7 +246,7 @@ def get_redis_connection( def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env): if sentinel_hosts_env: - sentinel_hosts = sentinel_hosts_env.split(",") + sentinel_hosts = sentinel_hosts_env.split(',') sentinel_port = int(sentinel_port_env) return [(host, sentinel_port) for host in sentinel_hosts] return [] @@ -263,12 +254,10 @@ def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env): def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env): redis_config = parse_redis_service_url(redis_url) - username = redis_config["username"] or "" - password = redis_config["password"] or "" - auth_part = "" + username = redis_config['username'] or '' + password = redis_config['password'] or '' + auth_part = '' if username or password: - auth_part = f"{username}:{password}@" - hosts_part = ",".join( - f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(",") - ) - return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}" + auth_part = f'{username}:{password}@' + hosts_part = ','.join(f'{host}:{sentinel_port_env}' for host in sentinel_hosts_env.split(',')) + return f'redis+sentinel://{auth_part}{hosts_part}/{redis_config["db"]}/{redis_config["service"]}' diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index 8785e374b0..ae911368a3 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -21,28 +21,28 @@ def normalize_usage(usage: dict) -> dict: # Map various field names to standard names input_tokens = ( - usage.get("input_tokens") # Already standard - or usage.get("prompt_tokens") # OpenAI - or usage.get("prompt_eval_count") # Ollama - or usage.get("prompt_n") # llama.cpp + usage.get('input_tokens') # Already standard + or usage.get('prompt_tokens') # OpenAI + or usage.get('prompt_eval_count') # Ollama + or usage.get('prompt_n') # llama.cpp or 0 ) output_tokens = ( - usage.get("output_tokens") # Already standard - or usage.get("completion_tokens") # OpenAI - or usage.get("eval_count") # Ollama - or usage.get("predicted_n") # llama.cpp + usage.get('output_tokens') # Already standard + or usage.get('completion_tokens') # OpenAI + or usage.get('eval_count') # Ollama + or usage.get('predicted_n') # llama.cpp or 0 ) - total_tokens = usage.get("total_tokens") or (input_tokens + output_tokens) + total_tokens = usage.get('total_tokens') or (input_tokens + output_tokens) # Add standardized fields to original data result = dict(usage) - result["input_tokens"] = int(input_tokens) - result["output_tokens"] = int(output_tokens) - result["total_tokens"] = int(total_tokens) + result['input_tokens'] = int(input_tokens) + result['output_tokens'] = int(output_tokens) + result['total_tokens'] = int(total_tokens) return result @@ -50,14 +50,14 @@ def normalize_usage(usage: dict) -> dict: def convert_ollama_tool_call_to_openai(tool_calls: list) -> list: openai_tool_calls = [] for tool_call in tool_calls: - function = tool_call.get("function", {}) + function = tool_call.get('function', {}) openai_tool_call = { - "index": tool_call.get("index", function.get("index", 0)), - "id": tool_call.get("id", f"call_{str(uuid4())}"), - "type": "function", - "function": { - "name": function.get("name", ""), - "arguments": json.dumps(function.get("arguments", {})), + 'index': tool_call.get('index', function.get('index', 0)), + 'id': tool_call.get('id', f'call_{str(uuid4())}'), + 'type': 'function', + 'function': { + 'name': function.get('name', ''), + 'arguments': json.dumps(function.get('arguments', {})), }, } openai_tool_calls.append(openai_tool_call) @@ -65,69 +65,57 @@ def convert_ollama_tool_call_to_openai(tool_calls: list) -> list: def convert_ollama_usage_to_openai(data: dict) -> dict: - input_tokens = int(data.get("prompt_eval_count", 0)) - output_tokens = int(data.get("eval_count", 0)) + input_tokens = int(data.get('prompt_eval_count', 0)) + output_tokens = int(data.get('eval_count', 0)) total_tokens = input_tokens + output_tokens return { # Standardized fields - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": total_tokens, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'total_tokens': total_tokens, # OpenAI-compatible fields (for backward compatibility) - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, + 'prompt_tokens': input_tokens, + 'completion_tokens': output_tokens, # Ollama-specific metrics - "response_token/s": ( + 'response_token/s': ( round( - ( - ( - data.get("eval_count", 0) - / ((data.get("eval_duration", 0) / 10_000_000)) - ) - * 100 - ), + ((data.get('eval_count', 0) / (data.get('eval_duration', 0) / 10_000_000)) * 100), 2, ) - if data.get("eval_duration", 0) > 0 - else "N/A" + if data.get('eval_duration', 0) > 0 + else 'N/A' ), - "prompt_token/s": ( + 'prompt_token/s': ( round( - ( - ( - data.get("prompt_eval_count", 0) - / ((data.get("prompt_eval_duration", 0) / 10_000_000)) - ) - * 100 - ), + ((data.get('prompt_eval_count', 0) / (data.get('prompt_eval_duration', 0) / 10_000_000)) * 100), 2, ) - if data.get("prompt_eval_duration", 0) > 0 - else "N/A" + if data.get('prompt_eval_duration', 0) > 0 + else 'N/A' ), - "total_duration": data.get("total_duration", 0), - "load_duration": data.get("load_duration", 0), - "prompt_eval_count": data.get("prompt_eval_count", 0), - "prompt_eval_duration": data.get("prompt_eval_duration", 0), - "eval_count": data.get("eval_count", 0), - "eval_duration": data.get("eval_duration", 0), - "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")( - (data.get("total_duration", 0) or 0) // 1_000_000_000 + 'total_duration': data.get('total_duration', 0), + 'load_duration': data.get('load_duration', 0), + 'prompt_eval_count': data.get('prompt_eval_count', 0), + 'prompt_eval_duration': data.get('prompt_eval_duration', 0), + 'eval_count': data.get('eval_count', 0), + 'eval_duration': data.get('eval_duration', 0), + 'approximate_total': (lambda s: f'{s // 3600}h{(s % 3600) // 60}m{s % 60}s')( + (data.get('total_duration', 0) or 0) // 1_000_000_000 ), - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0, + 'completion_tokens_details': { + 'reasoning_tokens': 0, + 'accepted_prediction_tokens': 0, + 'rejected_prediction_tokens': 0, }, } def convert_response_ollama_to_openai(ollama_response: dict) -> dict: - model = ollama_response.get("model", "ollama") - message_content = ollama_response.get("message", {}).get("content", "") - reasoning_content = ollama_response.get("message", {}).get("thinking", None) - tool_calls = ollama_response.get("message", {}).get("tool_calls", None) + model = ollama_response.get('model', 'ollama') + message_content = ollama_response.get('message', {}).get('content', '') + reasoning_content = ollama_response.get('message', {}).get('thinking', None) + tool_calls = ollama_response.get('message', {}).get('tool_calls', None) openai_tool_calls = None if tool_calls: @@ -148,33 +136,31 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) async for data in ollama_streaming_response.body_iterator: data = json.loads(data) - model = data.get("model", "ollama") - message_content = data.get("message", {}).get("content", None) - reasoning_content = data.get("message", {}).get("thinking", None) - tool_calls = data.get("message", {}).get("tool_calls", None) + model = data.get('model', 'ollama') + message_content = data.get('message', {}).get('content', None) + reasoning_content = data.get('message', {}).get('thinking', None) + tool_calls = data.get('message', {}).get('tool_calls', None) openai_tool_calls = None if tool_calls: openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) has_tool_calls = True - done = data.get("done", False) + done = data.get('done', False) usage = None if done: usage = convert_ollama_usage_to_openai(data) - data = openai_chat_chunk_message_template( - model, message_content, reasoning_content, openai_tool_calls, usage - ) + data = openai_chat_chunk_message_template(model, message_content, reasoning_content, openai_tool_calls, usage) if done and has_tool_calls: - data["choices"][0]["finish_reason"] = "tool_calls" + data['choices'][0]['finish_reason'] = 'tool_calls' - line = f"data: {json.dumps(data)}\n\n" + line = f'data: {json.dumps(data)}\n\n' yield line - yield "data: [DONE]\n\n" + yield 'data: [DONE]\n\n' def convert_embedding_response_ollama_to_openai(response) -> dict: @@ -199,51 +185,47 @@ def convert_embedding_response_ollama_to_openai(response) -> dict: """ # Ollama batch-style output from /api/embed # Response format: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]], "model": "..."} - if isinstance(response, dict) and "embeddings" in response: + if isinstance(response, dict) and 'embeddings' in response: openai_data = [] - for i, emb in enumerate(response["embeddings"]): + for i, emb in enumerate(response['embeddings']): # /api/embed returns embeddings as plain float lists if isinstance(emb, list): openai_data.append( { - "object": "embedding", - "embedding": emb, - "index": i, + 'object': 'embedding', + 'embedding': emb, + 'index': i, } ) # Also handle dict format for robustness elif isinstance(emb, dict): openai_data.append( { - "object": "embedding", - "embedding": emb.get("embedding"), - "index": emb.get("index", i), + 'object': 'embedding', + 'embedding': emb.get('embedding'), + 'index': emb.get('index', i), } ) return { - "object": "list", - "data": openai_data, - "model": response.get("model"), + 'object': 'list', + 'data': openai_data, + 'model': response.get('model'), } # Ollama single output - elif isinstance(response, dict) and "embedding" in response: + elif isinstance(response, dict) and 'embedding' in response: return { - "object": "list", - "data": [ + 'object': 'list', + 'data': [ { - "object": "embedding", - "embedding": response["embedding"], - "index": 0, + 'object': 'embedding', + 'embedding': response['embedding'], + 'index': 0, } ], - "model": response.get("model"), + 'model': response.get('model'), } # Already OpenAI-compatible? - elif ( - isinstance(response, dict) - and "data" in response - and isinstance(response["data"], list) - ): + elif isinstance(response, dict) and 'data' in response and isinstance(response['data'], list): return response # Fallback: return as is if unrecognized diff --git a/backend/open_webui/utils/sanitize.py b/backend/open_webui/utils/sanitize.py index 258b6d78fb..7b65df375a 100644 --- a/backend/open_webui/utils/sanitize.py +++ b/backend/open_webui/utils/sanitize.py @@ -2,9 +2,7 @@ import re # ANSI escape code pattern - matches all common ANSI sequences # This includes color codes, cursor movement, and other terminal control sequences -ANSI_ESCAPE_PATTERN = re.compile( - r"\x1b\[[0-9;]*[A-Za-z]|\x1b\([AB]|\x1b[PX^_].*?\x1b\\|\x1b\].*?(?:\x07|\x1b\\)" -) +ANSI_ESCAPE_PATTERN = re.compile(r'\x1b\[[0-9;]*[A-Za-z]|\x1b\([AB]|\x1b[PX^_].*?\x1b\\|\x1b\].*?(?:\x07|\x1b\\)') def strip_ansi_codes(text: str) -> str: @@ -20,7 +18,7 @@ def strip_ansi_codes(text: str) -> str: - Reset codes: \x1b[0m, \x1b[39m - Cursor movement: \x1b[1A, \x1b[2J, etc. """ - return ANSI_ESCAPE_PATTERN.sub("", text) + return ANSI_ESCAPE_PATTERN.sub('', text) def strip_markdown_code_fences(code: str) -> str: @@ -37,9 +35,9 @@ def strip_markdown_code_fences(code: str) -> str: """ code = code.strip() # Remove opening fence (```python, ```py, ``` etc.) - code = re.sub(r"^```\w*\n?", "", code) + code = re.sub(r'^```\w*\n?', '', code) # Remove closing fence - code = re.sub(r"\n?```\s*$", "", code) + code = re.sub(r'\n?```\s*$', '', code) return code.strip() diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py index 3b31c2c05c..33956688a1 100644 --- a/backend/open_webui/utils/security_headers.py +++ b/backend/open_webui/utils/security_headers.py @@ -39,16 +39,16 @@ def set_security_headers() -> Dict[str, str]: """ options = {} header_setters = { - "CACHE_CONTROL": set_cache_control, - "HSTS": set_hsts, - "PERMISSIONS_POLICY": set_permissions_policy, - "REFERRER_POLICY": set_referrer, - "XCONTENT_TYPE": set_xcontent_type, - "XDOWNLOAD_OPTIONS": set_xdownload_options, - "XFRAME_OPTIONS": set_xframe, - "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies, - "CONTENT_SECURITY_POLICY": set_content_security_policy, - "REPORTING_ENDPOINTS": set_reporting_endpoints, + 'CACHE_CONTROL': set_cache_control, + 'HSTS': set_hsts, + 'PERMISSIONS_POLICY': set_permissions_policy, + 'REFERRER_POLICY': set_referrer, + 'XCONTENT_TYPE': set_xcontent_type, + 'XDOWNLOAD_OPTIONS': set_xdownload_options, + 'XFRAME_OPTIONS': set_xframe, + 'XPERMITTED_CROSS_DOMAIN_POLICIES': set_xpermitted_cross_domain_policies, + 'CONTENT_SECURITY_POLICY': set_content_security_policy, + 'REPORTING_ENDPOINTS': set_reporting_endpoints, } for env_var, setter in header_setters.items(): @@ -63,78 +63,78 @@ def set_security_headers() -> Dict[str, str]: # Set HTTP Strict Transport Security(HSTS) response header def set_hsts(value: str): - pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$" + pattern = r'^max-age=(\d+)(;includeSubDomains)?(;preload)?$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "max-age=31536000;includeSubDomains" - return {"Strict-Transport-Security": value} + value = 'max-age=31536000;includeSubDomains' + return {'Strict-Transport-Security': value} # Set X-Frame-Options response header def set_xframe(value: str): - pattern = r"^(DENY|SAMEORIGIN)$" + pattern = r'^(DENY|SAMEORIGIN)$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "DENY" - return {"X-Frame-Options": value} + value = 'DENY' + return {'X-Frame-Options': value} # Set Permissions-Policy response header def set_permissions_policy(value: str): - pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$" + pattern = r'^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "none" - return {"Permissions-Policy": value} + value = 'none' + return {'Permissions-Policy': value} # Set Referrer-Policy response header def set_referrer(value: str): - pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$" + pattern = r'^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "no-referrer" - return {"Referrer-Policy": value} + value = 'no-referrer' + return {'Referrer-Policy': value} # Set Cache-Control response header def set_cache_control(value: str): - pattern = r"^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$" + pattern = r'^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "no-store, max-age=0" + value = 'no-store, max-age=0' - return {"Cache-Control": value} + return {'Cache-Control': value} # Set X-Download-Options response header def set_xdownload_options(value: str): - if value != "noopen": - value = "noopen" - return {"X-Download-Options": value} + if value != 'noopen': + value = 'noopen' + return {'X-Download-Options': value} # Set X-Content-Type-Options response header def set_xcontent_type(value: str): - if value != "nosniff": - value = "nosniff" - return {"X-Content-Type-Options": value} + if value != 'nosniff': + value = 'nosniff' + return {'X-Content-Type-Options': value} # Set X-Permitted-Cross-Domain-Policies response header def set_xpermitted_cross_domain_policies(value: str): - pattern = r"^(none|master-only|by-content-type|by-ftp-filename)$" + pattern = r'^(none|master-only|by-content-type|by-ftp-filename)$' match = re.match(pattern, value, re.IGNORECASE) if not match: - value = "none" - return {"X-Permitted-Cross-Domain-Policies": value} + value = 'none' + return {'X-Permitted-Cross-Domain-Policies': value} # Set Content-Security-Policy response header def set_content_security_policy(value: str): - return {"Content-Security-Policy": value} + return {'Content-Security-Policy': value} # Set Reporting-Endpoints response header def set_reporting_endpoints(value: str): - return {"Reporting-Endpoints": value} + return {'Reporting-Endpoints': value} diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 0ea525c93e..c640e7f5f8 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -13,13 +13,11 @@ from open_webui.config import DEFAULT_RAG_TEMPLATE log = logging.getLogger(__name__) -def get_task_model_id( - default_model_id: str, task_model: str, task_model_external: str, models -) -> str: +def get_task_model_id(default_model_id: str, task_model: str, task_model_external: str, models) -> str: # Set the task model task_model_id = default_model_id # Check if the user has a custom task model and use that model - if models[task_model_id].get("connection_type") == "local": + if models[task_model_id].get('connection_type') == 'local': if task_model and task_model in models: task_model_id = task_model else: @@ -36,92 +34,70 @@ def prompt_variables_template(template: str, variables: dict[str, str]) -> str: def prompt_template(template: str, user: Optional[Any] = None) -> str: - USER_VARIABLES = {} if user: - if hasattr(user, "model_dump"): + if hasattr(user, 'model_dump'): user = user.model_dump() if isinstance(user, dict): - user_info = user.get("info", {}) or {} - birth_date = user.get("date_of_birth") + user_info = user.get('info', {}) or {} + birth_date = user.get('date_of_birth') age = None if birth_date: try: # If birth_date is str, convert to datetime if isinstance(birth_date, str): - birth_date = datetime.strptime(birth_date, "%Y-%m-%d") + birth_date = datetime.strptime(birth_date, '%Y-%m-%d') today = datetime.now() - age = ( - today.year - - birth_date.year - - ( - (today.month, today.day) - < (birth_date.month, birth_date.day) - ) - ) + age = today.year - birth_date.year - ((today.month, today.day) < (birth_date.month, birth_date.day)) except Exception as e: pass USER_VARIABLES = { - "name": str(user.get("name")), - "email": str(user.get("email")), - "location": str(user_info.get("location")), - "bio": str(user.get("bio")), - "gender": str(user.get("gender")), - "birth_date": str(birth_date), - "age": str(age), + 'name': str(user.get('name')), + 'email': str(user.get('email')), + 'location': str(user_info.get('location')), + 'bio': str(user.get('bio')), + 'gender': str(user.get('gender')), + 'birth_date': str(birth_date), + 'age': str(age), } # Get the current date current_date = datetime.now() # Format the date to YYYY-MM-DD - formatted_date = current_date.strftime("%Y-%m-%d") - formatted_time = current_date.strftime("%I:%M:%S %p") - formatted_weekday = current_date.strftime("%A") + formatted_date = current_date.strftime('%Y-%m-%d') + formatted_time = current_date.strftime('%I:%M:%S %p') + formatted_weekday = current_date.strftime('%A') - template = template.replace("{{CURRENT_DATE}}", formatted_date) - template = template.replace("{{CURRENT_TIME}}", formatted_time) - template = template.replace( - "{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}" - ) - template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday) + template = template.replace('{{CURRENT_DATE}}', formatted_date) + template = template.replace('{{CURRENT_TIME}}', formatted_time) + template = template.replace('{{CURRENT_DATETIME}}', f'{formatted_date} {formatted_time}') + template = template.replace('{{CURRENT_WEEKDAY}}', formatted_weekday) - template = template.replace("{{USER_NAME}}", USER_VARIABLES.get("name", "Unknown")) - template = template.replace( - "{{USER_EMAIL}}", USER_VARIABLES.get("email", "Unknown") - ) - template = template.replace("{{USER_BIO}}", USER_VARIABLES.get("bio", "Unknown")) - template = template.replace( - "{{USER_GENDER}}", USER_VARIABLES.get("gender", "Unknown") - ) - template = template.replace( - "{{USER_BIRTH_DATE}}", USER_VARIABLES.get("birth_date", "Unknown") - ) - template = template.replace( - "{{USER_AGE}}", str(USER_VARIABLES.get("age", "Unknown")) - ) - template = template.replace( - "{{USER_LOCATION}}", USER_VARIABLES.get("location", "Unknown") - ) + template = template.replace('{{USER_NAME}}', USER_VARIABLES.get('name', 'Unknown')) + template = template.replace('{{USER_EMAIL}}', USER_VARIABLES.get('email', 'Unknown')) + template = template.replace('{{USER_BIO}}', USER_VARIABLES.get('bio', 'Unknown')) + template = template.replace('{{USER_GENDER}}', USER_VARIABLES.get('gender', 'Unknown')) + template = template.replace('{{USER_BIRTH_DATE}}', USER_VARIABLES.get('birth_date', 'Unknown')) + template = template.replace('{{USER_AGE}}', str(USER_VARIABLES.get('age', 'Unknown'))) + template = template.replace('{{USER_LOCATION}}', USER_VARIABLES.get('location', 'Unknown')) return template def replace_prompt_variable(template: str, prompt: str) -> str: def replacement_function(match): - full_match = match.group( - 0 - ).lower() # Normalize to lowercase for consistent handling + full_match = match.group(0).lower() # Normalize to lowercase for consistent handling start_length = match.group(1) end_length = match.group(2) middle_length = match.group(3) - if full_match == "{{prompt}}": + if full_match == '{{prompt}}': return prompt elif start_length is not None: return prompt[: int(start_length)] @@ -133,16 +109,16 @@ def replace_prompt_variable(template: str, prompt: str) -> str: return prompt start = prompt[: math.ceil(middle_length / 2)] end = prompt[-math.floor(middle_length / 2) :] - return f"{start}...{end}" - return "" + return f'{start}...{end}' + return '' # Updated regex pattern to make it case-insensitive with the `(?i)` flag - pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}" + pattern = r'(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}' template = re.sub(pattern, replacement_function, template) return template -def truncate_content(content: str, max_chars: int, mode: str = "middletruncate") -> str: +def truncate_content(content: str, max_chars: int, mode: str = 'middletruncate') -> str: """Truncate a string to max_chars using the specified mode. Modes: @@ -153,13 +129,13 @@ def truncate_content(content: str, max_chars: int, mode: str = "middletruncate") if not content or len(content) <= max_chars: return content - if mode == "start": + if mode == 'start': return content[:max_chars] - elif mode == "end": + elif mode == 'end': return content[-max_chars:] else: # middletruncate half = max_chars // 2 - return f"{content[:half]}...{content[-(max_chars - half):]}" + return f'{content[:half]}...{content[-(max_chars - half) :]}' def apply_content_filter(messages: list[dict], filter_str: str) -> list[dict]: @@ -168,7 +144,7 @@ def apply_content_filter(messages: list[dict], filter_str: str) -> list[dict]: filter_str is like 'middletruncate:500', 'start:200', or 'end:200'. Returns a new list with truncated content (original messages are not mutated). """ - parts = filter_str.split(":") + parts = filter_str.split(':') if len(parts) != 2: return messages @@ -178,33 +154,29 @@ def apply_content_filter(messages: list[dict], filter_str: str) -> list[dict]: except ValueError: return messages - if mode not in ("middletruncate", "start", "end"): + if mode not in ('middletruncate', 'start', 'end'): return messages result = [] for msg in messages: new_msg = dict(msg) - if isinstance(new_msg.get("content"), str): - new_msg["content"] = truncate_content(new_msg["content"], max_chars, mode) - elif isinstance(new_msg.get("content"), list): + if isinstance(new_msg.get('content'), str): + new_msg['content'] = truncate_content(new_msg['content'], max_chars, mode) + elif isinstance(new_msg.get('content'), list): new_content = [] - for item in new_msg["content"]: - if isinstance(item, dict) and item.get("type") == "text": + for item in new_msg['content']: + if isinstance(item, dict) and item.get('type') == 'text': new_item = dict(item) - new_item["text"] = truncate_content( - item.get("text", ""), max_chars, mode - ) + new_item['text'] = truncate_content(item.get('text', ''), max_chars, mode) new_content.append(new_item) else: new_content.append(item) - new_msg["content"] = new_content + new_msg['content'] = new_content result.append(new_msg) return result -def replace_messages_variable( - template: str, messages: Optional[list[dict]] = None -) -> str: +def replace_messages_variable(template: str, messages: Optional[list[dict]] = None) -> str: def replacement_function(match): # Groups: (1) filter for bare MESSAGES # (2) START count, (3) filter for START @@ -220,7 +192,7 @@ def replace_messages_variable( # If messages is None, handle it as an empty list if messages is None: - return "" + return '' # Select messages based on the variant if start_length is not None: @@ -251,12 +223,12 @@ def replace_messages_variable( return get_messages_content(selected) template = re.sub( - r"(?:" - r"\{\{MESSAGES(?:\|(\w+:\d+))?\}\}" - r"|\{\{MESSAGES:START:(\d+)(?:\|(\w+:\d+))?\}\}" - r"|\{\{MESSAGES:END:(\d+)(?:\|(\w+:\d+))?\}\}" - r"|\{\{MESSAGES:MIDDLETRUNCATE:(\d+)(?:\|(\w+:\d+))?\}\}" - r")", + r'(?:' + r'\{\{MESSAGES(?:\|(\w+:\d+))?\}\}' + r'|\{\{MESSAGES:START:(\d+)(?:\|(\w+:\d+))?\}\}' + r'|\{\{MESSAGES:END:(\d+)(?:\|(\w+:\d+))?\}\}' + r'|\{\{MESSAGES:MIDDLETRUNCATE:(\d+)(?:\|(\w+:\d+))?\}\}' + r')', replacement_function, template, ) @@ -268,39 +240,37 @@ def replace_messages_variable( def rag_template(template: str, context: str, query: str): - if template.strip() == "": + if template.strip() == '': template = DEFAULT_RAG_TEMPLATE template = prompt_template(template) - if "[context]" not in template and "{{CONTEXT}}" not in template: - log.debug( - "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder." - ) + if '[context]' not in template and '{{CONTEXT}}' not in template: + log.debug("WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder.") - if "" in context and "" in context: + if '' in context and '' in context: log.debug( - "WARNING: Potential prompt injection attack: the RAG " + 'WARNING: Potential prompt injection attack: the RAG ' "context contains '' and ''. This might be " - "nothing, or the user might be trying to hack something." + 'nothing, or the user might be trying to hack something.' ) query_placeholders = [] - if "[query]" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("[query]", query_placeholder) - query_placeholders.append((query_placeholder, "[query]")) + if '[query]' in context: + query_placeholder = '{{QUERY' + str(uuid.uuid4()) + '}}' + template = template.replace('[query]', query_placeholder) + query_placeholders.append((query_placeholder, '[query]')) - if "{{QUERY}}" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("{{QUERY}}", query_placeholder) - query_placeholders.append((query_placeholder, "{{QUERY}}")) + if '{{QUERY}}' in context: + query_placeholder = '{{QUERY' + str(uuid.uuid4()) + '}}' + template = template.replace('{{QUERY}}', query_placeholder) + query_placeholders.append((query_placeholder, '{{QUERY}}')) - template = template.replace("[context]", context) - template = template.replace("{{CONTEXT}}", context) + template = template.replace('[context]', context) + template = template.replace('{{CONTEXT}}', context) - template = template.replace("[query]", query) - template = template.replace("{{QUERY}}", query) + template = template.replace('[query]', query) + template = template.replace('{{QUERY}}', query) for query_placeholder, original_placeholder in query_placeholders: template = template.replace(query_placeholder, original_placeholder) @@ -308,10 +278,7 @@ def rag_template(template: str, context: str, query: str): return template -def title_generation_template( - template: str, messages: list[dict], user: Optional[Any] = None -) -> str: - +def title_generation_template(template: str, messages: list[dict], user: Optional[Any] = None) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -321,9 +288,7 @@ def title_generation_template( return template -def follow_up_generation_template( - template: str, messages: list[dict], user: Optional[Any] = None -) -> str: +def follow_up_generation_template(template: str, messages: list[dict], user: Optional[Any] = None) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -332,9 +297,7 @@ def follow_up_generation_template( return template -def tags_generation_template( - template: str, messages: list[dict], user: Optional[Any] = None -) -> str: +def tags_generation_template(template: str, messages: list[dict], user: Optional[Any] = None) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -343,9 +306,7 @@ def tags_generation_template( return template -def image_prompt_generation_template( - template: str, messages: list[dict], user: Optional[Any] = None -) -> str: +def image_prompt_generation_template(template: str, messages: list[dict], user: Optional[Any] = None) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -354,9 +315,7 @@ def image_prompt_generation_template( return template -def emoji_generation_template( - template: str, prompt: str, user: Optional[Any] = None -) -> str: +def emoji_generation_template(template: str, prompt: str, user: Optional[Any] = None) -> str: template = replace_prompt_variable(template, prompt) template = prompt_template(template, user) @@ -370,7 +329,7 @@ def autocomplete_generation_template( type: Optional[str] = None, user: Optional[Any] = None, ) -> str: - template = template.replace("{{TYPE}}", type if type else "") + template = template.replace('{{TYPE}}', type if type else '') template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -378,9 +337,7 @@ def autocomplete_generation_template( return template -def query_generation_template( - template: str, messages: list[dict], user: Optional[Any] = None -) -> str: +def query_generation_template(template: str, messages: list[dict], user: Optional[Any] = None) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) @@ -389,16 +346,14 @@ def query_generation_template( return template -def moa_response_generation_template( - template: str, prompt: str, responses: list[str] -) -> str: +def moa_response_generation_template(template: str, prompt: str, responses: list[str]) -> str: def replacement_function(match): full_match = match.group(0) start_length = match.group(1) end_length = match.group(2) middle_length = match.group(3) - if full_match == "{{prompt}}": + if full_match == '{{prompt}}': return prompt elif start_length is not None: return prompt[: int(start_length)] @@ -410,22 +365,22 @@ def moa_response_generation_template( return prompt start = prompt[: math.ceil(middle_length / 2)] end = prompt[-math.floor(middle_length / 2) :] - return f"{start}...{end}" - return "" + return f'{start}...{end}' + return '' template = re.sub( - r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + r'{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}', replacement_function, template, ) responses = [f'"""{response}"""' for response in responses] - responses = "\n\n".join(responses) + responses = '\n\n'.join(responses) - template = template.replace("{{responses}}", responses) + template = template.replace('{{responses}}', responses) return template def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: - template = template.replace("{{TOOLS}}", tools_specs) + template = template.replace('{{TOOLS}}', tools_specs) return template diff --git a/backend/open_webui/utils/telemetry/constants.py b/backend/open_webui/utils/telemetry/constants.py index 6ef511f934..1f2102a86f 100644 --- a/backend/open_webui/utils/telemetry/constants.py +++ b/backend/open_webui/utils/telemetry/constants.py @@ -1,12 +1,12 @@ from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes # Span Tags -SPAN_DB_TYPE = "mysql" -SPAN_REDIS_TYPE = "redis" -SPAN_DURATION = "duration" -SPAN_SQL_STR = "sql" -SPAN_SQL_EXPLAIN = "explain" -SPAN_ERROR_TYPE = "error" +SPAN_DB_TYPE = 'mysql' +SPAN_REDIS_TYPE = 'redis' +SPAN_DURATION = 'duration' +SPAN_SQL_STR = 'sql' +SPAN_SQL_EXPLAIN = 'explain' +SPAN_ERROR_TYPE = 'error' class SpanAttributes(_SpanAttributes): @@ -14,13 +14,13 @@ class SpanAttributes(_SpanAttributes): Span Attributes """ - DB_INSTANCE = "db.instance" - DB_TYPE = "db.type" - DB_IP = "db.ip" - DB_PORT = "db.port" - ERROR_KIND = "error.kind" - ERROR_OBJECT = "error.object" - ERROR_MESSAGE = "error.message" - RESULT_CODE = "result.code" - RESULT_MESSAGE = "result.message" - RESULT_ERRORS = "result.errors" + DB_INSTANCE = 'db.instance' + DB_TYPE = 'db.type' + DB_IP = 'db.ip' + DB_PORT = 'db.port' + ERROR_KIND = 'error.kind' + ERROR_OBJECT = 'error.object' + ERROR_MESSAGE = 'error.message' + RESULT_CODE = 'result.code' + RESULT_MESSAGE = 'result.message' + RESULT_ERRORS = 'result.errors' diff --git a/backend/open_webui/utils/telemetry/instrumentors.py b/backend/open_webui/utils/telemetry/instrumentors.py index 25cd027d0e..394e7178d6 100644 --- a/backend/open_webui/utils/telemetry/instrumentors.py +++ b/backend/open_webui/utils/telemetry/instrumentors.py @@ -38,7 +38,7 @@ def requests_hook(span: Span, request: PreparedRequest): Http Request Hook """ - span.update_name(f"{request.method} {request.url}") + span.update_name(f'{request.method} {request.url}') span.set_attributes( attributes={ SpanAttributes.HTTP_URL: request.url, @@ -70,8 +70,8 @@ def redis_request_hook(span: Span, instance: Union[Redis | RedisCluster], args, # - redis.cluster.RedisCluster # Instead of checking the type, we check if the instance has a nodes_manager attribute. try: - db = "" - if hasattr(instance, "nodes_manager"): + db = '' + if hasattr(instance, 'nodes_manager'): default_node = instance.nodes_manager.default_node if not default_node: return @@ -79,17 +79,17 @@ def redis_request_hook(span: Span, instance: Union[Redis | RedisCluster], args, port = default_node.port else: connection_kwargs: dict = instance.connection_pool.connection_kwargs - host = connection_kwargs.get("host") - port = connection_kwargs.get("port") - db = connection_kwargs.get("db") + host = connection_kwargs.get('host') + port = connection_kwargs.get('port') + db = connection_kwargs.get('db') span.set_attributes( { - SpanAttributes.DB_INSTANCE: f"{host}/{db}", - SpanAttributes.DB_NAME: f"{host}/{db}", + SpanAttributes.DB_INSTANCE: f'{host}/{db}', + SpanAttributes.DB_NAME: f'{host}/{db}', SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE, SpanAttributes.DB_PORT: port, SpanAttributes.DB_IP: host, - SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]), + SpanAttributes.DB_STATEMENT: ' '.join([str(i) for i in args]), SpanAttributes.DB_OPERATION: str(args[0]), } ) @@ -102,7 +102,7 @@ def httpx_request_hook(span: Span, request: RequestInfo): HTTPX Request Hook """ - span.update_name(f"{request.method.decode()} {str(request.url)}") + span.update_name(f'{request.method.decode()} {str(request.url)}') span.set_attributes( attributes={ SpanAttributes.HTTP_URL: str(request.url), @@ -117,11 +117,7 @@ def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo """ span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code) - span.set_status( - StatusCode.ERROR - if response.status_code >= status.HTTP_400_BAD_REQUEST - else StatusCode.OK - ) + span.set_status(StatusCode.ERROR if response.status_code >= status.HTTP_400_BAD_REQUEST else StatusCode.OK) async def httpx_async_request_hook(span: Span, request: RequestInfo): @@ -132,9 +128,7 @@ async def httpx_async_request_hook(span: Span, request: RequestInfo): httpx_request_hook(span, request) -async def httpx_async_response_hook( - span: Span, request: RequestInfo, response: ResponseInfo -): +async def httpx_async_response_hook(span: Span, request: RequestInfo, response: ResponseInfo): """ Async Response Hook """ @@ -147,7 +141,7 @@ def aiohttp_request_hook(span: Span, request: TraceRequestStartParams): Aiohttp Request Hook """ - span.update_name(f"{request.method} {str(request.url)}") + span.update_name(f'{request.method} {str(request.url)}') span.set_attributes( attributes={ SpanAttributes.HTTP_URL: str(request.url), @@ -156,20 +150,14 @@ def aiohttp_request_hook(span: Span, request: TraceRequestStartParams): ) -def aiohttp_response_hook( - span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams] -): +def aiohttp_response_hook(span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams]): """ Aiohttp Response Hook """ if isinstance(response, TraceRequestEndParams): span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.response.status) - span.set_status( - StatusCode.ERROR - if response.response.status >= status.HTTP_400_BAD_REQUEST - else StatusCode.OK - ) + span.set_status(StatusCode.ERROR if response.response.status >= status.HTTP_400_BAD_REQUEST else StatusCode.OK) elif isinstance(response, TraceRequestExceptionParams): span.set_status(StatusCode.ERROR) span.set_attribute(SpanAttributes.ERROR_MESSAGE, str(response.exception)) @@ -191,9 +179,7 @@ class Instrumentor(BaseInstrumentor): instrument_fastapi(app=self.app) SQLAlchemyInstrumentor().instrument(engine=self.db_engine) RedisInstrumentor().instrument(request_hook=redis_request_hook) - RequestsInstrumentor().instrument( - request_hook=requests_hook, response_hook=response_hook - ) + RequestsInstrumentor().instrument(request_hook=requests_hook, response_hook=response_hook) LoggingInstrumentor().instrument() HTTPXClientInstrumentor().instrument( request_hook=httpx_request_hook, @@ -208,7 +194,7 @@ class Instrumentor(BaseInstrumentor): SystemMetricsInstrumentor().instrument() def _uninstrument(self, **kwargs): - if getattr(self, "instrumentors", None) is None: + if getattr(self, 'instrumentors', None) is None: return for instrumentor in self.instrumentors: instrumentor.uninstrument() diff --git a/backend/open_webui/utils/telemetry/logs.py b/backend/open_webui/utils/telemetry/logs.py index 00d3e28c07..e501c99cea 100644 --- a/backend/open_webui/utils/telemetry/logs.py +++ b/backend/open_webui/utils/telemetry/logs.py @@ -24,12 +24,12 @@ from open_webui.env import ( def setup_logging(): headers = [] if OTEL_LOGS_BASIC_AUTH_USERNAME and OTEL_LOGS_BASIC_AUTH_PASSWORD: - auth_string = f"{OTEL_LOGS_BASIC_AUTH_USERNAME}:{OTEL_LOGS_BASIC_AUTH_PASSWORD}" + auth_string = f'{OTEL_LOGS_BASIC_AUTH_USERNAME}:{OTEL_LOGS_BASIC_AUTH_PASSWORD}' auth_header = b64encode(auth_string.encode()).decode() - headers = [("authorization", f"Basic {auth_header}")] + headers = [('authorization', f'Basic {auth_header}')] resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME}) - if OTEL_LOGS_OTLP_SPAN_EXPORTER == "http": + if OTEL_LOGS_OTLP_SPAN_EXPORTER == 'http': exporter = HttpOTLPLogExporter( endpoint=OTEL_LOGS_EXPORTER_OTLP_ENDPOINT, headers=headers, diff --git a/backend/open_webui/utils/telemetry/metrics.py b/backend/open_webui/utils/telemetry/metrics.py index 5ff1129b02..4c43de3342 100644 --- a/backend/open_webui/utils/telemetry/metrics.py +++ b/backend/open_webui/utils/telemetry/metrics.py @@ -53,19 +53,15 @@ def _build_meter_provider(resource: Resource) -> MeterProvider: """Return a configured MeterProvider.""" headers = [] if OTEL_METRICS_BASIC_AUTH_USERNAME and OTEL_METRICS_BASIC_AUTH_PASSWORD: - auth_string = ( - f"{OTEL_METRICS_BASIC_AUTH_USERNAME}:{OTEL_METRICS_BASIC_AUTH_PASSWORD}" - ) + auth_string = f'{OTEL_METRICS_BASIC_AUTH_USERNAME}:{OTEL_METRICS_BASIC_AUTH_PASSWORD}' auth_header = b64encode(auth_string.encode()).decode() - headers = [("authorization", f"Basic {auth_header}")] + headers = [('authorization', f'Basic {auth_header}')] # Periodic reader pushes metrics over OTLP/gRPC to collector - if OTEL_METRICS_OTLP_SPAN_EXPORTER == "http": + if OTEL_METRICS_OTLP_SPAN_EXPORTER == 'http': readers: List[PeriodicExportingMetricReader] = [ PeriodicExportingMetricReader( - OTLPHttpMetricExporter( - endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, headers=headers - ), + OTLPHttpMetricExporter(endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, headers=headers), export_interval_millis=OTEL_METRICS_EXPORT_INTERVAL_MILLIS, ) ] @@ -84,21 +80,21 @@ def _build_meter_provider(resource: Resource) -> MeterProvider: # Optional view to limit cardinality: drop user-agent etc. views: List[View] = [ View( - instrument_name="http.server.duration", - attribute_keys=["http.method", "http.route", "http.status_code"], + instrument_name='http.server.duration', + attribute_keys=['http.method', 'http.route', 'http.status_code'], ), View( - instrument_name="http.server.requests", - attribute_keys=["http.method", "http.route", "http.status_code"], + instrument_name='http.server.requests', + attribute_keys=['http.method', 'http.route', 'http.status_code'], ), View( - instrument_name="webui.users.total", + instrument_name='webui.users.total', ), View( - instrument_name="webui.users.active", + instrument_name='webui.users.active', ), View( - instrument_name="webui.users.active.today", + instrument_name='webui.users.active.today', ), ] @@ -118,14 +114,14 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None: # Instruments request_counter = meter.create_counter( - name="http.server.requests", - description="Counts the total number of inbound HTTP requests.", - unit="1", + name='http.server.requests', + description='Counts the total number of inbound HTTP requests.', + unit='1', ) duration_histogram = meter.create_histogram( - name="http.server.duration", - description="Measures the duration of inbound HTTP requests.", - unit="ms", + name='http.server.duration', + description='Measures the duration of inbound HTTP requests.', + unit='ms', ) def observe_active_users( @@ -150,16 +146,16 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None: ] meter.create_observable_gauge( - name="webui.users.total", - description="Total number of registered users", - unit="users", + name='webui.users.total', + description='Total number of registered users', + unit='users', callbacks=[observe_total_registered_users], ) meter.create_observable_gauge( - name="webui.users.active", - description="Number of currently active users", - unit="users", + name='webui.users.active', + description='Number of currently active users', + unit='users', callbacks=[observe_active_users], ) @@ -169,21 +165,21 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None: return [metrics.Observation(value=Users.get_num_users_active_today())] meter.create_observable_gauge( - name="webui.users.active.today", - description="Number of users active since midnight today", - unit="users", + name='webui.users.active.today', + description='Number of users active since midnight today', + unit='users', callbacks=[observe_users_active_today], ) # FastAPI middleware - @app.middleware("http") + @app.middleware('http') async def _metrics_middleware(request: Request, call_next): start_time = time.perf_counter() status_code = None try: response = await call_next(request) - status_code = getattr(response, "status_code", 500) + status_code = getattr(response, 'status_code', 500) return response except Exception: status_code = 500 @@ -192,13 +188,13 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None: elapsed_ms = (time.perf_counter() - start_time) * 1000.0 # Route template e.g. "/items/{item_id}" instead of real path. - route = request.scope.get("route") - route_path = getattr(route, "path", request.url.path) + route = request.scope.get('route') + route_path = getattr(route, 'path', request.url.path) attrs: Dict[str, str | int] = { - "http.method": request.method, - "http.route": route_path, - "http.status_code": status_code, + 'http.method': request.method, + 'http.route': route_path, + 'http.status_code': status_code, } request_counter.add(1, attrs) diff --git a/backend/open_webui/utils/telemetry/setup.py b/backend/open_webui/utils/telemetry/setup.py index 36294b4e56..744dced2d0 100644 --- a/backend/open_webui/utils/telemetry/setup.py +++ b/backend/open_webui/utils/telemetry/setup.py @@ -34,12 +34,12 @@ def setup(app: FastAPI, db_engine: Engine): # Add basic auth header only if both username and password are not empty headers = [] if OTEL_BASIC_AUTH_USERNAME and OTEL_BASIC_AUTH_PASSWORD: - auth_string = f"{OTEL_BASIC_AUTH_USERNAME}:{OTEL_BASIC_AUTH_PASSWORD}" + auth_string = f'{OTEL_BASIC_AUTH_USERNAME}:{OTEL_BASIC_AUTH_PASSWORD}' auth_header = b64encode(auth_string.encode()).decode() - headers = [("authorization", f"Basic {auth_header}")] + headers = [('authorization', f'Basic {auth_header}')] # otlp export - if OTEL_OTLP_SPAN_EXPORTER == "http": + if OTEL_OTLP_SPAN_EXPORTER == 'http': exporter = HttpOTLPSpanExporter( endpoint=OTEL_EXPORTER_OTLP_ENDPOINT, headers=headers, diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 17fc2cca5e..8febdfa31d 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -89,9 +89,7 @@ import copy log = logging.getLogger(__name__) -def get_async_tool_function_and_apply_extra_params( - function: Callable, extra_params: dict -) -> Callable[..., Awaitable]: +def get_async_tool_function_and_apply_extra_params(function: Callable, extra_params: dict) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) @@ -106,9 +104,7 @@ def get_async_tool_function_and_apply_extra_params( # Keep remaining parameters parameters.append(parameter) - new_sig = inspect.Signature( - parameters=parameters, return_annotation=sig.return_annotation - ) + new_sig = inspect.Signature(parameters=parameters, return_annotation=sig.return_annotation) if inspect.iscoroutinefunction(function): # wrap the functools.partial as python-genai has trouble with it @@ -132,8 +128,8 @@ def get_async_tool_function_and_apply_extra_params( def get_updated_tool_function(function: Callable, extra_params: dict): # Get the original function and merge updated params - __function__ = getattr(function, "__function__", None) - __extra_params__ = getattr(function, "__extra_params__", None) + __function__ = getattr(function, '__function__', None) + __extra_params__ = getattr(function, '__extra_params__', None) if __function__ is not None and __extra_params__ is not None: return get_async_tool_function_and_apply_extra_params( @@ -144,9 +140,7 @@ def get_updated_tool_function(function: Callable, extra_params: dict): return function -async def get_tools( - request: Request, tool_ids: list[str], user: UserModel, extra_params: dict -) -> dict[str, dict]: +async def get_tools(request: Request, tool_ids: list[str], user: UserModel, extra_params: dict) -> dict[str, dict]: """Load tools for the given tool_ids, checking access control.""" if not tool_ids: return {} @@ -161,17 +155,17 @@ async def get_tools( if tool: # Check access control for local tools if ( - not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + not (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) and tool.user_id != user.id and not AccessGrants.has_access( user_id=user.id, - resource_type="tool", + resource_type='tool', resource_id=tool.id, - permission="read", + permission='read', user_group_ids=user_group_ids, ) ): - log.warning(f"Access denied to tool {tool_id} for user {user.id}") + log.warning(f'Access denied to tool {tool_id} for user {user.id}') continue module = request.app.state.TOOLS.get(tool_id, None) @@ -180,164 +174,146 @@ async def get_tools( request.app.state.TOOLS[tool_id] = module __user__ = { - **extra_params["__user__"], + **extra_params['__user__'], } # Set valves for the tool - if hasattr(module, "valves") and hasattr(module, "Valves"): + if hasattr(module, 'valves') and hasattr(module, 'Valves'): valves = Tools.get_tool_valves_by_id(tool_id) or {} module.valves = module.Valves(**valves) - if hasattr(module, "UserValves"): - __user__["valves"] = module.UserValves( # type: ignore + if hasattr(module, 'UserValves'): + __user__['valves'] = module.UserValves( # type: ignore **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) for spec in tool.specs: # TODO: Fix hack for OpenAI API # Some times breaks OpenAI but others don't. Leaving the comment - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val.get("type") == "str": - val["type"] = "string" + for val in spec.get('parameters', {}).get('properties', {}).values(): + if val.get('type') == 'str': + val['type'] = 'string' # Remove internal reserved parameters (e.g. __id__, __user__) - spec["parameters"]["properties"] = { - key: val - for key, val in spec["parameters"]["properties"].items() - if not key.startswith("__") + spec['parameters']['properties'] = { + key: val for key, val in spec['parameters']['properties'].items() if not key.startswith('__') } # convert to function that takes only model params and inserts custom params - function_name = spec["name"] + function_name = spec['name'] tool_function = getattr(module, function_name) callable = get_async_tool_function_and_apply_extra_params( tool_function, { **extra_params, - "__id__": tool_id, - "__user__": __user__, + '__id__': tool_id, + '__user__': __user__, }, ) # TODO: Support Pydantic models as parameters - if callable.__doc__ and callable.__doc__.strip() != "": - s = re.split(":(param|return)", callable.__doc__, 1) - spec["description"] = s[0] + if callable.__doc__ and callable.__doc__.strip() != '': + s = re.split(':(param|return)', callable.__doc__, 1) + spec['description'] = s[0] else: - spec["description"] = function_name + spec['description'] = function_name tool_dict = { - "tool_id": tool_id, - "callable": callable, - "spec": spec, + 'tool_id': tool_id, + 'callable': callable, + 'spec': spec, # Misc info - "metadata": { - "file_handler": hasattr(module, "file_handler") - and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, + 'metadata': { + 'file_handler': hasattr(module, 'file_handler') and module.file_handler, + 'citation': hasattr(module, 'citation') and module.citation, }, } # Handle function name collisions while function_name in tools_dict: - log.warning( - f"Tool {function_name} already exists in another tools!" - ) + log.warning(f'Tool {function_name} already exists in another tools!') # Prepend tool ID to function name - function_name = f"{tool_id}_{function_name}" + function_name = f'{tool_id}_{function_name}' tools_dict[function_name] = tool_dict else: - if tool_id.startswith("server:"): - splits = tool_id.split(":") + if tool_id.startswith('server:'): + splits = tool_id.split(':') if len(splits) == 2: - type = "openapi" + type = 'openapi' server_id = splits[1] elif len(splits) == 3: type = splits[1] server_id = splits[2] - server_id_splits = server_id.split("|") + server_id_splits = server_id.split('|') if len(server_id_splits) == 2: server_id = server_id_splits[0] - function_names = server_id_splits[1].split(",") - - if type == "openapi": + function_names = server_id_splits[1].split(',') + if type == 'openapi': tool_server_data = None for server in await get_tool_servers(request): - if server["id"] == server_id: + if server['id'] == server_id: tool_server_data = server break if tool_server_data is None: - log.warning(f"Tool server data not found for {server_id}") + log.warning(f'Tool server data not found for {server_id}') continue - tool_server_idx = tool_server_data.get("idx", 0) + tool_server_idx = tool_server_data.get('idx', 0) connections = request.app.state.config.TOOL_SERVER_CONNECTIONS if tool_server_idx >= len(connections): log.warning( - f"Tool server index {tool_server_idx} out of range " - f"(have {len(connections)} connections), skipping server {server_id}" + f'Tool server index {tool_server_idx} out of range ' + f'(have {len(connections)} connections), skipping server {server_id}' ) continue tool_server_connection = connections[tool_server_idx] # Check access control for tool server - if not has_connection_access( - user, tool_server_connection, user_group_ids - ): - log.warning( - f"Access denied to tool server {server_id} for user {user.id}" - ) + if not has_connection_access(user, tool_server_connection, user_group_ids): + log.warning(f'Access denied to tool server {server_id} for user {user.id}') continue - specs = tool_server_data.get("specs", []) - function_name_filter_list = tool_server_connection.get( - "config", {} - ).get("function_name_filter_list", "") + specs = tool_server_data.get('specs', []) + function_name_filter_list = tool_server_connection.get('config', {}).get( + 'function_name_filter_list', '' + ) if isinstance(function_name_filter_list, str): - function_name_filter_list = function_name_filter_list.split(",") + function_name_filter_list = function_name_filter_list.split(',') for spec in specs: - function_name = spec["name"] + function_name = spec['name'] if function_name_filter_list: - if not is_string_allowed( - function_name, function_name_filter_list - ): + if not is_string_allowed(function_name, function_name_filter_list): # Skip this function continue - auth_type = tool_server_connection.get("auth_type", "bearer") + auth_type = tool_server_connection.get('auth_type', 'bearer') cookies = {} headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json', } - if auth_type == "bearer": - headers["Authorization"] = ( - f"Bearer {tool_server_connection.get('key', '')}" - ) - elif auth_type == "none": + if auth_type == 'bearer': + headers['Authorization'] = f'Bearer {tool_server_connection.get("key", "")}' + elif auth_type == 'none': # No authentication pass - elif auth_type == "session": + elif auth_type == 'session': cookies = request.cookies - headers["Authorization"] = ( - f"Bearer {request.state.token.credentials}" - ) - elif auth_type == "system_oauth": + headers['Authorization'] = f'Bearer {request.state.token.credentials}' + elif auth_type == 'system_oauth': cookies = request.cookies - oauth_token = extra_params.get("__oauth_token__", None) + oauth_token = extra_params.get('__oauth_token__', None) if oauth_token: - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) + headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}' - connection_headers = tool_server_connection.get("headers", None) + connection_headers = tool_server_connection.get('headers', None) if connection_headers and isinstance(connection_headers, dict): for key, value in connection_headers.items(): headers[key] = value @@ -345,22 +321,16 @@ async def get_tools( # Add user info headers if enabled if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - metadata = extra_params.get("__metadata__", {}) - if metadata and metadata.get("chat_id"): - headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = ( - metadata.get("chat_id") - ) - if metadata and metadata.get("message_id"): - headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = ( - metadata.get("message_id") - ) + metadata = extra_params.get('__metadata__', {}) + if metadata and metadata.get('chat_id'): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id') + if metadata and metadata.get('message_id'): + headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = metadata.get('message_id') - def make_tool_function( - function_name, tool_server_data, headers - ): + def make_tool_function(function_name, tool_server_data, headers): async def tool_function(**kwargs): return await execute_tool_server( - url=tool_server_data["url"], + url=tool_server_data['url'], headers=headers, cookies=cookies, name=function_name, @@ -370,9 +340,7 @@ async def get_tools( return tool_function - tool_function = make_tool_function( - function_name, tool_server_data, headers - ) + tool_function = make_tool_function(function_name, tool_server_data, headers) callable = get_async_tool_function_and_apply_extra_params( tool_function, @@ -380,20 +348,18 @@ async def get_tools( ) tool_dict = { - "tool_id": tool_id, - "callable": callable, - "spec": clean_openai_tool_schema(spec), + 'tool_id': tool_id, + 'callable': callable, + 'spec': clean_openai_tool_schema(spec), # Misc info - "type": "external", + 'type': 'external', } # Handle function name collisions while function_name in tools_dict: - log.warning( - f"Tool {function_name} already exists in another tools!" - ) + log.warning(f'Tool {function_name} already exists in another tools!') # Prepend server ID to function name - function_name = f"{server_id}_{function_name}" + function_name = f'{server_id}_{function_name}' tools_dict[function_name] = tool_dict @@ -417,37 +383,35 @@ def get_builtin_tools( # Helper to get model capabilities (defaults to True if not specified) def get_model_capability(name: str, default: bool = True) -> bool: - return (model.get("info", {}).get("meta", {}).get("capabilities") or {}).get( - name, default - ) + return (model.get('info', {}).get('meta', {}).get('capabilities') or {}).get(name, default) # Helper to check if a builtin tool category is enabled via meta.builtinTools # Defaults to True if not specified (backward compatible) def is_builtin_tool_enabled(category: str) -> bool: - builtin_tools = model.get("info", {}).get("meta", {}).get("builtinTools", {}) + builtin_tools = model.get('info', {}).get('meta', {}).get('builtinTools', {}) return builtin_tools.get(category, True) # Time utilities - available for date calculations - if is_builtin_tool_enabled("time"): + if is_builtin_tool_enabled('time'): builtin_functions.extend([get_current_timestamp, calculate_timestamp]) # Knowledge base tools - conditional injection based on model knowledge # If model has attached knowledge (any type), only provide query_knowledge_files # Otherwise, provide all KB browsing tools - model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", []) + model_knowledge = model.get('info', {}).get('meta', {}).get('knowledge', []) # Merge folder-attached knowledge so builtin tools can search it - folder_knowledge = extra_params.get("__metadata__", {}).get("folder_knowledge") + folder_knowledge = extra_params.get('__metadata__', {}).get('folder_knowledge') if folder_knowledge: model_knowledge = list(model_knowledge or []) + list(folder_knowledge) - if is_builtin_tool_enabled("knowledge"): + if is_builtin_tool_enabled('knowledge'): if model_knowledge: # Model has attached knowledge - only allow semantic search within it builtin_functions.append(query_knowledge_files) - knowledge_types = {item.get("type") for item in model_knowledge} - if "file" in knowledge_types or "collection" in knowledge_types: + knowledge_types = {item.get('type') for item in model_knowledge} + if 'file' in knowledge_types or 'collection' in knowledge_types: builtin_functions.append(view_file) - if "note" in knowledge_types: + if 'note' in knowledge_types: builtin_functions.append(view_note) else: # No model knowledge - allow full KB browsing @@ -463,11 +427,11 @@ def get_builtin_tools( ) # Chats tools - search and fetch user's chat history - if is_builtin_tool_enabled("chats"): + if is_builtin_tool_enabled('chats'): builtin_functions.extend([search_chats, view_chat]) # Add memory tools if builtin category enabled AND enabled for this chat - if is_builtin_tool_enabled("memory") and features.get("memory"): + if is_builtin_tool_enabled('memory') and features.get('memory'): builtin_functions.extend( [ search_memories, @@ -480,50 +444,44 @@ def get_builtin_tools( # Add web search tools if builtin category enabled AND enabled globally AND model has web_search capability if ( - is_builtin_tool_enabled("web_search") - and getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) - and get_model_capability("web_search") - and features.get("web_search") + is_builtin_tool_enabled('web_search') + and getattr(request.app.state.config, 'ENABLE_WEB_SEARCH', False) + and get_model_capability('web_search') + and features.get('web_search') ): builtin_functions.extend([search_web, fetch_url]) # Add image generation/edit tools if builtin category enabled AND enabled globally AND model has image_generation capability if ( - is_builtin_tool_enabled("image_generation") - and getattr(request.app.state.config, "ENABLE_IMAGE_GENERATION", False) - and get_model_capability("image_generation") - and features.get("image_generation") + is_builtin_tool_enabled('image_generation') + and getattr(request.app.state.config, 'ENABLE_IMAGE_GENERATION', False) + and get_model_capability('image_generation') + and features.get('image_generation') ): builtin_functions.append(generate_image) if ( - is_builtin_tool_enabled("image_generation") - and getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) - and get_model_capability("image_generation") - and features.get("image_generation") + is_builtin_tool_enabled('image_generation') + and getattr(request.app.state.config, 'ENABLE_IMAGE_EDIT', False) + and get_model_capability('image_generation') + and features.get('image_generation') ): builtin_functions.append(edit_image) # Add code interpreter tool if builtin category enabled AND enabled globally AND model has code_interpreter capability if ( - is_builtin_tool_enabled("code_interpreter") - and getattr(request.app.state.config, "ENABLE_CODE_INTERPRETER", True) - and get_model_capability("code_interpreter") - and features.get("code_interpreter") + is_builtin_tool_enabled('code_interpreter') + and getattr(request.app.state.config, 'ENABLE_CODE_INTERPRETER', True) + and get_model_capability('code_interpreter') + and features.get('code_interpreter') ): builtin_functions.append(execute_code) # Notes tools - search, view, create, and update user's notes (if builtin category enabled AND notes enabled globally) - if is_builtin_tool_enabled("notes") and getattr( - request.app.state.config, "ENABLE_NOTES", False - ): - builtin_functions.extend( - [search_notes, view_note, write_note, replace_note_content] - ) + if is_builtin_tool_enabled('notes') and getattr(request.app.state.config, 'ENABLE_NOTES', False): + builtin_functions.extend([search_notes, view_note, write_note, replace_note_content]) # Channels tools - search channels and messages (if builtin category enabled AND channels enabled globally) - if is_builtin_tool_enabled("channels") and getattr( - request.app.state.config, "ENABLE_CHANNELS", False - ): + if is_builtin_tool_enabled('channels') and getattr(request.app.state.config, 'ENABLE_CHANNELS', False): builtin_functions.extend( [ search_channels, @@ -534,21 +492,21 @@ def get_builtin_tools( ) # Skills tools - view_skill allows model to load full skill instructions on demand - if extra_params.get("__skill_ids__"): + if extra_params.get('__skill_ids__'): builtin_functions.append(view_skill) for func in builtin_functions: callable = get_async_tool_function_and_apply_extra_params( func, { - "__request__": request, - "__user__": extra_params.get("__user__", {}), - "__event_emitter__": extra_params.get("__event_emitter__"), - "__event_call__": extra_params.get("__event_call__"), - "__metadata__": extra_params.get("__metadata__"), - "__chat_id__": extra_params.get("__chat_id__"), - "__message_id__": extra_params.get("__message_id__"), - "__model_knowledge__": model_knowledge, + '__request__': request, + '__user__': extra_params.get('__user__', {}), + '__event_emitter__': extra_params.get('__event_emitter__'), + '__event_call__': extra_params.get('__event_call__'), + '__metadata__': extra_params.get('__metadata__'), + '__chat_id__': extra_params.get('__chat_id__'), + '__message_id__': extra_params.get('__message_id__'), + '__model_knowledge__': model_knowledge, }, ) @@ -558,10 +516,10 @@ def get_builtin_tools( spec = clean_openai_tool_schema(spec) tools_dict[func.__name__] = { - "tool_id": f"builtin:{func.__name__}", - "callable": callable, - "spec": spec, - "type": "builtin", + 'tool_id': f'builtin:{func.__name__}', + 'callable': callable, + 'spec': spec, + 'type': 'builtin', } return tools_dict @@ -579,18 +537,18 @@ def parse_description(docstring: str | None) -> str: """ if not docstring: - return "" + return '' - lines = [line.strip() for line in docstring.strip().split("\n")] + lines = [line.strip() for line in docstring.strip().split('\n')] description_lines: list[str] = [] for line in lines: - if re.match(r":param", line) or re.match(r":return", line): + if re.match(r':param', line) or re.match(r':return', line): break description_lines.append(line) - return "\n".join(description_lines) + return '\n'.join(description_lines) def parse_docstring(docstring): @@ -607,7 +565,7 @@ def parse_docstring(docstring): return {} # Regex to match `:param name: description` format - param_pattern = re.compile(r":param (\w+):\s*(.+)") + param_pattern = re.compile(r':param (\w+):\s*(.+)') param_descriptions = {} for line in docstring.splitlines(): @@ -615,7 +573,7 @@ def parse_docstring(docstring): if not match: continue param_name, param_description = match.groups() - if param_name.startswith("__"): + if param_name.startswith('__'): continue param_descriptions[param_name] = param_description @@ -668,27 +626,27 @@ def clean_properties(schema: dict): if not isinstance(schema, dict): return - if "anyOf" in schema: - non_null_types = [t for t in schema["anyOf"] if t.get("type") != "null"] + if 'anyOf' in schema: + non_null_types = [t for t in schema['anyOf'] if t.get('type') != 'null'] if len(non_null_types) == 1: schema.update(non_null_types[0]) - del schema["anyOf"] + del schema['anyOf'] else: - schema["anyOf"] = non_null_types + schema['anyOf'] = non_null_types - if "default" in schema and schema["default"] is None: - del schema["default"] + if 'default' in schema and schema['default'] is None: + del schema['default'] # fix missing type - if "type" not in schema and "anyOf" not in schema and "properties" not in schema: - schema["type"] = "string" + if 'type' not in schema and 'anyOf' not in schema and 'properties' not in schema: + schema['type'] = 'string' - if "properties" in schema: - for prop_name, prop_schema in schema["properties"].items(): + if 'properties' in schema: + for prop_name, prop_schema in schema['properties'].items(): clean_properties(prop_schema) - if "items" in schema: - clean_properties(schema["items"]) + if 'items' in schema: + clean_properties(schema['items']) def clean_openai_tool_schema(spec: dict) -> dict: @@ -696,8 +654,8 @@ def clean_openai_tool_schema(spec: dict) -> dict: cleaned_spec = copy.deepcopy(spec) - if "parameters" in cleaned_spec: - clean_properties(cleaned_spec["parameters"]) + if 'parameters' in cleaned_spec: + clean_properties(cleaned_spec['parameters']) return cleaned_spec @@ -706,12 +664,8 @@ def get_functions_from_tool(tool: object) -> list[Callable]: return [ getattr(tool, func) for func in dir(tool) - if callable( - getattr(tool, func) - ) # checks if the attribute is callable (a method or function). - and not func.startswith( - "_" - ) # filters out internal methods (starting with _) and special (dunder) methods. + if callable(getattr(tool, func)) # checks if the attribute is callable (a method or function). + and not func.startswith('_') # filters out internal methods (starting with _) and special (dunder) methods. and not inspect.isclass( getattr(tool, func) ) # ensures that the callable is not a class itself, just a method or function. @@ -719,14 +673,10 @@ def get_functions_from_tool(tool: object) -> list[Callable]: def get_tool_specs(tool_module: object) -> list[dict]: - function_models = map( - convert_function_to_pydantic_model, get_functions_from_tool(tool_module) - ) + function_models = map(convert_function_to_pydantic_model, get_functions_from_tool(tool_module)) specs = [ - clean_openai_tool_schema( - convert_pydantic_model_to_openai_function_spec(function_model) - ) + clean_openai_tool_schema(convert_pydantic_model_to_openai_function_spec(function_model)) for function_model in function_models ] @@ -740,9 +690,9 @@ def resolve_schema(schema, components): if not schema: return {} - if "$ref" in schema: - ref_path = schema["$ref"] - ref_parts = ref_path.strip("#/").split("/") + if '$ref' in schema: + ref_path = schema['$ref'] + ref_parts = ref_path.strip('#/').split('/') resolved = components for part in ref_parts[1:]: # Skip the initial 'components' resolved = resolved.get(part, {}) @@ -751,14 +701,12 @@ def resolve_schema(schema, components): resolved_schema = copy.deepcopy(schema) # Recursively resolve inner schemas - if "properties" in resolved_schema: - for prop, prop_schema in resolved_schema["properties"].items(): - resolved_schema["properties"][prop] = resolve_schema( - prop_schema, components - ) + if 'properties' in resolved_schema: + for prop, prop_schema in resolved_schema['properties'].items(): + resolved_schema['properties'][prop] = resolve_schema(prop_schema, components) - if "items" in resolved_schema: - resolved_schema["items"] = resolve_schema(resolved_schema["items"], components) + if 'items' in resolved_schema: + resolved_schema['items'] = resolve_schema(resolved_schema['items'], components) return resolved_schema @@ -775,75 +723,60 @@ def convert_openapi_to_tool_payload(openapi_spec): """ tool_payload = [] - for path, methods in openapi_spec.get("paths", {}).items(): + for path, methods in openapi_spec.get('paths', {}).items(): for method, operation in methods.items(): - if operation.get("operationId"): + if operation.get('operationId'): tool = { - "name": operation.get("operationId"), - "description": operation.get( - "description", - operation.get("summary", "No description available."), + 'name': operation.get('operationId'), + 'description': operation.get( + 'description', + operation.get('summary', 'No description available.'), ), - "parameters": {"type": "object", "properties": {}, "required": []}, + 'parameters': {'type': 'object', 'properties': {}, 'required': []}, } - for param in operation.get("parameters", []): - param_name = param.get("name") + for param in operation.get('parameters', []): + param_name = param.get('name') if not param_name: continue - param_schema = param.get("schema", {}) - description = param_schema.get("description", "") + param_schema = param.get('schema', {}) + description = param_schema.get('description', '') if not description: - description = param.get("description") or "" - if param_schema.get("enum") and isinstance( - param_schema.get("enum"), list - ): - description += ( - f". Possible values: {', '.join(param_schema.get('enum'))}" - ) + description = param.get('description') or '' + if param_schema.get('enum') and isinstance(param_schema.get('enum'), list): + description += f'. Possible values: {", ".join(param_schema.get("enum"))}' param_property = { - "type": param_schema.get("type") or "string", - "description": description, + 'type': param_schema.get('type') or 'string', + 'description': description, } # Include items property for array types (required by OpenAI) - if param_schema.get("type") == "array" and "items" in param_schema: - param_property["items"] = param_schema["items"] + if param_schema.get('type') == 'array' and 'items' in param_schema: + param_property['items'] = param_schema['items'] # Filter out None values to prevent schema validation errors - param_property = { - k: v for k, v in param_property.items() if v is not None - } + param_property = {k: v for k, v in param_property.items() if v is not None} - tool["parameters"]["properties"][param_name] = param_property - if param.get("required"): - tool["parameters"]["required"].append(param_name) + tool['parameters']['properties'][param_name] = param_property + if param.get('required'): + tool['parameters']['required'].append(param_name) # Extract and resolve requestBody if available - request_body = operation.get("requestBody") + request_body = operation.get('requestBody') if request_body: - content = request_body.get("content", {}) - json_schema = content.get("application/json", {}).get("schema") + content = request_body.get('content', {}) + json_schema = content.get('application/json', {}).get('schema') if json_schema: - resolved_schema = resolve_schema( - json_schema, openapi_spec.get("components", {}) - ) + resolved_schema = resolve_schema(json_schema, openapi_spec.get('components', {})) - if resolved_schema.get("properties"): - tool["parameters"]["properties"].update( - resolved_schema["properties"] - ) - if "required" in resolved_schema: - tool["parameters"]["required"] = list( - set( - tool["parameters"]["required"] - + resolved_schema["required"] - ) + if resolved_schema.get('properties'): + tool['parameters']['properties'].update(resolved_schema['properties']) + if 'required' in resolved_schema: + tool['parameters']['required'] = list( + set(tool['parameters']['required'] + resolved_schema['required']) ) - elif resolved_schema.get("type") == "array": - tool["parameters"] = ( - resolved_schema # special case for array - ) + elif resolved_schema.get('type') == 'array': + tool['parameters'] = resolved_schema # special case for array tool_payload.append(tool) @@ -851,14 +784,10 @@ def convert_openapi_to_tool_payload(openapi_spec): async def set_tool_servers(request: Request): - request.app.state.TOOL_SERVERS = await get_tool_servers_data( - request.app.state.config.TOOL_SERVER_CONNECTIONS - ) + request.app.state.TOOL_SERVERS = await get_tool_servers_data(request.app.state.config.TOOL_SERVER_CONNECTIONS) if request.app.state.redis is not None: - await request.app.state.redis.set( - "tool_servers", json.dumps(request.app.state.TOOL_SERVERS) - ) + await request.app.state.redis.set('tool_servers', json.dumps(request.app.state.TOOL_SERVERS)) return request.app.state.TOOL_SERVERS @@ -867,10 +796,10 @@ async def get_tool_servers(request: Request): tool_servers = [] if request.app.state.redis is not None: try: - tool_servers = json.loads(await request.app.state.redis.get("tool_servers")) + tool_servers = json.loads(await request.app.state.redis.get('tool_servers')) request.app.state.TOOL_SERVERS = tool_servers except Exception as e: - log.error(f"Error fetching tool_servers from Redis: {e}") + log.error(f'Error fetching tool_servers from Redis: {e}') if not tool_servers: tool_servers = await set_tool_servers(request) @@ -885,19 +814,17 @@ async def get_terminal_cwd( ) -> Optional[str]: """Fetch the current working directory from a terminal server.""" try: - cwd_url = f"{base_url.rstrip('/')}/files/cwd" + cwd_url = f'{base_url.rstrip("/")}/files/cwd' async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=5), trust_env=True, ) as session: - async with session.get( - cwd_url, headers=headers, cookies=cookies or {} - ) as resp: + async with session.get(cwd_url, headers=headers, cookies=cookies or {}) as resp: if resp.status == 200: data = await resp.json() - return data.get("cwd") + return data.get('cwd') except Exception as e: - log.debug(f"Failed to fetch terminal CWD: {e}") + log.debug(f'Failed to fetch terminal CWD: {e}') return None @@ -909,31 +836,31 @@ async def set_terminal_servers(request: Request): # Terminal connections store id/name at top level; translate to info dict server_configs = [] for connection in connections: - if not connection.get("url"): + if not connection.get('url'): continue - enabled = connection.get("enabled", True) + enabled = connection.get('enabled', True) - base_url = connection.get("url", "").rstrip("/") - policy_id = connection.get("policy_id", "") + base_url = connection.get('url', '').rstrip('/') + policy_id = connection.get('policy_id', '') # Orchestrator connections route through /p/{policy_id}/ — the # OpenAPI spec lives on the proxied terminal, not the orchestrator. - if connection.get("server_type") == "orchestrator" and policy_id: - base_url = f"{base_url}/p/{policy_id}" + if connection.get('server_type') == 'orchestrator' and policy_id: + base_url = f'{base_url}/p/{policy_id}' server_configs.append( { - "url": base_url, - "key": connection.get("key", ""), - "auth_type": connection.get("auth_type", "bearer"), - "path": connection.get("path", "/openapi.json"), - "spec_type": "url", + 'url': base_url, + 'key': connection.get('key', ''), + 'auth_type': connection.get('auth_type', 'bearer'), + 'path': connection.get('path', '/openapi.json'), + 'spec_type': 'url', # get_tool_servers_data reads config.enable to filter active servers - "config": {"enable": enabled}, - "info": { - "id": connection.get("id", ""), - "name": connection.get("name", ""), + 'config': {'enable': enabled}, + 'info': { + 'id': connection.get('id', ''), + 'name': connection.get('name', ''), }, } ) @@ -941,9 +868,7 @@ async def set_terminal_servers(request: Request): request.app.state.TERMINAL_SERVERS = await get_tool_servers_data(server_configs) if request.app.state.redis is not None: - await request.app.state.redis.set( - "terminal_servers", json.dumps(request.app.state.TERMINAL_SERVERS) - ) + await request.app.state.redis.set('terminal_servers', json.dumps(request.app.state.TERMINAL_SERVERS)) return request.app.state.TERMINAL_SERVERS @@ -953,12 +878,10 @@ async def get_terminal_servers(request: Request): terminal_servers = [] if request.app.state.redis is not None: try: - terminal_servers = json.loads( - await request.app.state.redis.get("terminal_servers") - ) + terminal_servers = json.loads(await request.app.state.redis.get('terminal_servers')) request.app.state.TERMINAL_SERVERS = terminal_servers except Exception as e: - log.error(f"Error fetching terminal_servers from Redis: {e}") + log.error(f'Error fetching terminal_servers from Redis: {e}') if not terminal_servers: terminal_servers = await set_terminal_servers(request) @@ -980,64 +903,61 @@ async def get_terminal_tools( - Builds callables that route through the terminal proxy """ connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or [] - connection = next((c for c in connections if c.get("id") == terminal_id), None) + connection = next((c for c in connections if c.get('id') == terminal_id), None) if connection is None: - log.warning(f"Terminal server not found: {terminal_id}") + log.warning(f'Terminal server not found: {terminal_id}') return {} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} if not has_connection_access(user, connection, user_group_ids): - log.warning(f"Access denied to terminal {terminal_id} for user {user.id}") + log.warning(f'Access denied to terminal {terminal_id} for user {user.id}') return {} # Find the cached spec data for this terminal terminal_servers = await get_terminal_servers(request) - server_data = next( - (s for s in terminal_servers if s.get("id") == terminal_id), None - ) + server_data = next((s for s in terminal_servers if s.get('id') == terminal_id), None) if server_data is None: - log.warning(f"Terminal server spec not found for {terminal_id}") + log.warning(f'Terminal server spec not found for {terminal_id}') return {} - specs = server_data.get("specs", []) + specs = server_data.get('specs', []) if not specs: return {} # Build auth headers - auth_type = connection.get("auth_type", "bearer") + auth_type = connection.get('auth_type', 'bearer') cookies = {} - headers = {"Content-Type": "application/json", "X-User-Id": user.id} + headers = {'Content-Type': 'application/json', 'X-User-Id': user.id} - if auth_type == "bearer": - headers["Authorization"] = f"Bearer {connection.get('key', '')}" - elif auth_type == "session": + if auth_type == 'bearer': + headers['Authorization'] = f'Bearer {connection.get("key", "")}' + elif auth_type == 'session': cookies = request.cookies - headers["Authorization"] = f"Bearer {request.state.token.credentials}" - elif auth_type == "system_oauth": + headers['Authorization'] = f'Bearer {request.state.token.credentials}' + elif auth_type == 'system_oauth': cookies = request.cookies - oauth_token = extra_params.get("__oauth_token__", None) + oauth_token = extra_params.get('__oauth_token__', None) if oauth_token: - headers["Authorization"] = f"Bearer {oauth_token.get('access_token', '')}" + headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}' # auth_type == "none": no Authorization header - terminal_cwd = await get_terminal_cwd(connection.get("url", ""), headers, cookies) + terminal_cwd = await get_terminal_cwd(connection.get('url', ''), headers, cookies) tools_dict = {} for spec in specs: - function_name = spec["name"] + function_name = spec['name'] # Inject CWD into run_command description tool_spec = clean_openai_tool_schema(spec) - if function_name == "run_command" and terminal_cwd: - tool_spec["description"] = ( - tool_spec.get("description", "") - + f"\n\nThe current working directory is: {terminal_cwd}" + if function_name == 'run_command' and terminal_cwd: + tool_spec['description'] = ( + tool_spec.get('description', '') + f'\n\nThe current working directory is: {terminal_cwd}' ) def make_tool_function(fn_name, srv_data, hdrs, cks): async def tool_function(**kwargs): return await execute_tool_server( - url=srv_data["url"], + url=srv_data['url'], headers=hdrs, cookies=cks, name=fn_name, @@ -1051,10 +971,10 @@ async def get_terminal_tools( callable = get_async_tool_function_and_apply_extra_params(tool_function, {}) tools_dict[function_name] = { - "tool_id": f"terminal:{terminal_id}", - "callable": callable, - "spec": tool_spec, - "type": "terminal", + 'tool_id': f'terminal:{terminal_id}', + 'callable': callable, + 'spec': tool_spec, + 'type': 'terminal', } return tools_dict @@ -1062,8 +982,8 @@ async def get_terminal_tools( async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]: _headers = { - "Accept": "application/json", - "Content-Type": "application/json", + 'Accept': 'application/json', + 'Content-Type': 'application/json', } if headers: @@ -1073,9 +993,7 @@ async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, A try: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get( - url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL - ) as response: + async with session.get(url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL) as response: if response.status != 200: error_body = await response.json() raise Exception(error_body) @@ -1083,7 +1001,7 @@ async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, A text_content = None # Check if URL ends with .yaml or .yml to determine format - if url.lower().endswith((".yaml", ".yml")): + if url.lower().endswith(('.yaml', '.yml')): text_content = await response.text() res = yaml.safe_load(text_content) else: @@ -1098,14 +1016,14 @@ async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, A raise e except Exception as err: - log.exception(f"Could not fetch tool server spec from {url}") - if isinstance(err, dict) and "detail" in err: - error = err["detail"] + log.exception(f'Could not fetch tool server spec from {url}') + if isinstance(err, dict) and 'detail' in err: + error = err['detail'] else: error = str(err) raise Exception(error) - log.debug(f"Fetched data: {res}") + log.debug(f'Fetched data: {res}') return res @@ -1115,46 +1033,43 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, tasks = [] server_entries = [] for idx, server in enumerate(servers): - if ( - server.get("config", {}).get("enable") - and server.get("type", "openapi") == "openapi" - ): - info = server.get("info", {}) + if server.get('config', {}).get('enable') and server.get('type', 'openapi') == 'openapi': + info = server.get('info', {}) - auth_type = server.get("auth_type", "bearer") + auth_type = server.get('auth_type', 'bearer') token = None - if auth_type == "bearer": - token = server.get("key", "") - elif auth_type == "none": + if auth_type == 'bearer': + token = server.get('key', '') + elif auth_type == 'none': # No authentication pass - id = info.get("id") + id = info.get('id') if not id: id = str(idx) - server_url = server.get("url") - spec_type = server.get("spec_type", "url") + server_url = server.get('url') + spec_type = server.get('spec_type', 'url') # Create async tasks to fetch data task = None - if spec_type == "url": + if spec_type == 'url': # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL - openapi_path = server.get("path", "openapi.json") + openapi_path = server.get('path', 'openapi.json') spec_url = get_tool_server_url(server_url, openapi_path) # Fetch from URL task = get_tool_server_data( spec_url, - {"Authorization": f"Bearer {token}"} if token else None, + {'Authorization': f'Bearer {token}'} if token else None, ) - elif spec_type == "json" and server.get("spec", ""): + elif spec_type == 'json' and server.get('spec', ''): # Use provided JSON spec spec_json = None try: - spec_json = json.loads(server.get("spec", "")) + spec_json = json.loads(server.get('spec', '')) except Exception as e: - log.error(f"Error parsing JSON spec for tool server {id}: {e}") + log.error(f'Error parsing JSON spec for tool server {id}: {e}') if spec_json: task = asyncio.sleep( @@ -1173,38 +1088,38 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, results = [] for (id, idx, server, url, info, _), response in zip(server_entries, responses): if isinstance(response, Exception): - log.error(f"Failed to connect to {url} OpenAPI tool server") + log.error(f'Failed to connect to {url} OpenAPI tool server') continue # Guard against invalid or non-OpenAPI specs (e.g., MCP-style configs) - if not isinstance(response, dict) or "paths" not in response: + if not isinstance(response, dict) or 'paths' not in response: log.warning(f"Invalid OpenAPI spec from {url}: missing 'paths'") continue response = { - "openapi": response, - "info": response.get("info", {}), - "specs": convert_openapi_to_tool_payload(response), + 'openapi': response, + 'info': response.get('info', {}), + 'specs': convert_openapi_to_tool_payload(response), } - openapi_data = response.get("openapi", {}) + openapi_data = response.get('openapi', {}) if info and isinstance(openapi_data, dict): - openapi_data["info"] = openapi_data.get("info", {}) + openapi_data['info'] = openapi_data.get('info', {}) - if "name" in info: - openapi_data["info"]["title"] = info.get("name", "Tool Server") + if 'name' in info: + openapi_data['info']['title'] = info.get('name', 'Tool Server') - if "description" in info: - openapi_data["info"]["description"] = info.get("description", "") + if 'description' in info: + openapi_data['info']['description'] = info.get('description', '') results.append( { - "id": str(id), - "idx": idx, - "url": (server.get("url") or "").rstrip("/"), - "openapi": openapi_data, - "info": response.get("info"), - "specs": response.get("specs"), + 'id': str(id), + 'idx': idx, + 'url': (server.get('url') or '').rstrip('/'), + 'openapi': openapi_data, + 'info': response.get('info'), + 'specs': response.get('specs'), } ) @@ -1221,31 +1136,31 @@ async def execute_tool_server( ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: error = None try: - openapi = server_data.get("openapi", {}) - paths = openapi.get("paths", {}) + openapi = server_data.get('openapi', {}) + paths = openapi.get('paths', {}) matching_route = None for route_path, methods in paths.items(): for http_method, operation in methods.items(): - if isinstance(operation, dict) and operation.get("operationId") == name: + if isinstance(operation, dict) and operation.get('operationId') == name: matching_route = (route_path, methods) break if matching_route: break if not matching_route: - raise Exception(f"No matching route found for operationId: {name}") + raise Exception(f'No matching route found for operationId: {name}') route_path, methods = matching_route method_entry = None for http_method, operation in methods.items(): - if operation.get("operationId") == name: + if operation.get('operationId') == name: method_entry = (http_method.lower(), operation) break if not method_entry: - raise Exception(f"No matching method found for operationId: {name}") + raise Exception(f'No matching method found for operationId: {name}') http_method, operation = method_entry @@ -1253,27 +1168,27 @@ async def execute_tool_server( query_params = {} body_params = {} - for param in operation.get("parameters", []): - param_name = param.get("name") + for param in operation.get('parameters', []): + param_name = param.get('name') if not param_name: continue - param_in = param.get("in") + param_in = param.get('in') if param_name in params: - if param_in == "path": + if param_in == 'path': path_params[param_name] = params[param_name] - elif param_in == "query": + elif param_in == 'query': if params[param_name] is not None: query_params[param_name] = params[param_name] - final_url = f"{url.rstrip('/')}{route_path}" + final_url = f'{url.rstrip("/")}{route_path}' for key, value in path_params.items(): - final_url = final_url.replace(f"{{{key}}}", str(value)) + final_url = final_url.replace(f'{{{key}}}', str(value)) if query_params: - query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) - final_url = f"{final_url}?{query_string}" + query_string = '&'.join(f'{k}={v}' for k, v in query_params.items()) + final_url = f'{final_url}?{query_string}' - if operation.get("requestBody", {}).get("content"): + if operation.get('requestBody', {}).get('content'): if params: body_params = params @@ -1282,7 +1197,7 @@ async def execute_tool_server( ) as session: request_method = getattr(session, http_method.lower()) - if http_method in ["post", "put", "patch", "delete"]: + if http_method in ['post', 'put', 'patch', 'delete']: async with request_method( final_url, json=body_params, @@ -1293,7 +1208,7 @@ async def execute_tool_server( ) as response: if response.status >= 400: text = await response.text() - raise Exception(f"HTTP error {response.status}: {text}") + raise Exception(f'HTTP error {response.status}: {text}') try: response_data = await response.json() @@ -1312,7 +1227,7 @@ async def execute_tool_server( ) as response: if response.status >= 400: text = await response.text() - raise Exception(f"HTTP error {response.status}: {text}") + raise Exception(f'HTTP error {response.status}: {text}') try: response_data = await response.json() @@ -1324,20 +1239,20 @@ async def execute_tool_server( except Exception as err: error = str(err) - log.exception(f"API Request Error: {error}") - return ({"error": error}, None) + log.exception(f'API Request Error: {error}') + return ({'error': error}, None) def get_tool_server_url(url: Optional[str], path: str) -> str: """ Build the full URL for a tool server, given a base url and a path. """ - if "://" in path: + if '://' in path: # If it contains "://", it's a full URL return path if url: - url = url.rstrip("/") - if not path.startswith("/"): + url = url.rstrip('/') + if not path.startswith('/'): # Ensure the path starts with a slash - path = f"/{path}" - return f"{url}{path}" + path = f'/{path}' + return f'{url}{path}' diff --git a/backend/open_webui/utils/validate.py b/backend/open_webui/utils/validate.py index 6e62dd5416..c2064de257 100644 --- a/backend/open_webui/utils/validate.py +++ b/backend/open_webui/utils/validate.py @@ -2,8 +2,8 @@ # Known static asset paths used as default profile images _ALLOWED_STATIC_PATHS = ( - "/user.png", - "/static/favicon.png", + '/user.png', + '/static/favicon.png', ) @@ -22,10 +22,10 @@ def validate_profile_image_url(url: str) -> str: return url _ALLOWED_DATA_PREFIXES = ( - "data:image/png", - "data:image/jpeg", - "data:image/gif", - "data:image/webp", + 'data:image/png', + 'data:image/jpeg', + 'data:image/gif', + 'data:image/webp', ) if any(url.startswith(prefix) for prefix in _ALLOWED_DATA_PREFIXES): return url @@ -33,6 +33,4 @@ def validate_profile_image_url(url: str) -> str: if url in _ALLOWED_STATIC_PATHS: return url - raise ValueError( - "Invalid profile image URL: only data URIs and default avatars are allowed." - ) + raise ValueError('Invalid profile image URL: only data URIs and default avatars are allowed.') diff --git a/backend/open_webui/utils/webhook.py b/backend/open_webui/utils/webhook.py index b3a3c6bcd1..800450dfd3 100644 --- a/backend/open_webui/utils/webhook.py +++ b/backend/open_webui/utils/webhook.py @@ -10,42 +10,36 @@ log = logging.getLogger(__name__) async def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool: try: - log.debug(f"post_webhook: {url}, {message}, {event_data}") + log.debug(f'post_webhook: {url}, {message}, {event_data}') payload = {} # Slack and Google Chat Webhooks - if "https://hooks.slack.com" in url or "https://chat.googleapis.com" in url: - payload["text"] = message + if 'https://hooks.slack.com' in url or 'https://chat.googleapis.com' in url: + payload['text'] = message # Discord Webhooks - elif "https://discord.com/api/webhooks" in url: - payload["content"] = ( - message - if len(message) < 2000 - else f"{message[: 2000 - 20]}... (truncated)" - ) + elif 'https://discord.com/api/webhooks' in url: + payload['content'] = message if len(message) < 2000 else f'{message[: 2000 - 20]}... (truncated)' # Microsoft Teams Webhooks - elif "webhook.office.com" in url: - action = event_data.get("action", "undefined") - user_data = event_data.get("user", "{}") + elif 'webhook.office.com' in url: + action = event_data.get('action', 'undefined') + user_data = event_data.get('user', '{}') if isinstance(user_data, dict): user_dict = user_data else: user_dict = json.loads(user_data) - facts = [ - {"name": name, "value": value} for name, value in user_dict.items() - ] + facts = [{'name': name, 'value': value} for name, value in user_dict.items()] payload = { - "@type": "MessageCard", - "@context": "http://schema.org/extensions", - "themeColor": "0076D7", - "summary": message, - "sections": [ + '@type': 'MessageCard', + '@context': 'http://schema.org/extensions', + 'themeColor': '0076D7', + 'summary': message, + 'sections': [ { - "activityTitle": message, - "activitySubtitle": f"{name} ({VERSION}) - {action}", - "activityImage": WEBUI_FAVICON_URL, - "facts": facts, - "markdown": True, + 'activityTitle': message, + 'activitySubtitle': f'{name} ({VERSION}) - {action}', + 'activityImage': WEBUI_FAVICON_URL, + 'facts': facts, + 'markdown': True, } ], } @@ -53,14 +47,14 @@ async def post_webhook(name: str, url: str, message: str, event_data: dict) -> b else: payload = {**event_data} - log.debug(f"payload: {payload}") + log.debug(f'payload: {payload}') async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: async with session.post(url, json=payload) as r: r_text = await r.text() r.raise_for_status() - log.debug(f"r.text: {r_text}") + log.debug(f'r.text: {r_text}') return True except Exception as e: diff --git a/contribution_stats.py b/contribution_stats.py index 3caa4738ec..2dd9eab187 100644 --- a/contribution_stats.py +++ b/contribution_stats.py @@ -2,15 +2,15 @@ import os import subprocess from collections import Counter -CONFIG_FILE_EXTENSIONS = (".json", ".yml", ".yaml", ".ini", ".conf", ".toml") +CONFIG_FILE_EXTENSIONS = ('.json', '.yml', '.yaml', '.ini', '.conf', '.toml') def is_text_file(filepath): # Check for binary file by scanning for null bytes. try: - with open(filepath, "rb") as f: + with open(filepath, 'rb') as f: chunk = f.read(4096) - if b"\0" in chunk: + if b'\0' in chunk: return False return True except Exception: @@ -20,7 +20,7 @@ def is_text_file(filepath): def should_skip_file(path): base = os.path.basename(path) # Skip dotfiles and dotdirs - if base.startswith("."): + if base.startswith('.'): return True # Skip config files by extension if base.lower().endswith(CONFIG_FILE_EXTENSIONS): @@ -30,12 +30,12 @@ def should_skip_file(path): def get_tracked_files(): try: - output = subprocess.check_output(["git", "ls-files"], text=True) - files = output.strip().split("\n") + output = subprocess.check_output(['git', 'ls-files'], text=True) + files = output.strip().split('\n') files = [f for f in files if f and os.path.isfile(f)] return files except subprocess.CalledProcessError: - print("Error: Are you in a git repository?") + print('Error: Are you in a git repository?') return [] @@ -50,14 +50,12 @@ def main(): if not is_text_file(file): continue try: - blame = subprocess.check_output( - ["git", "blame", "-e", file], text=True, errors="replace" - ) + blame = subprocess.check_output(['git', 'blame', '-e', file], text=True, errors='replace') for line in blame.splitlines(): # The email always inside <> - if "<" in line and ">" in line: + if '<' in line and '>' in line: try: - email = line.split("<")[1].split(">")[0].strip() + email = line.split('<')[1].split('>')[0].strip() except Exception: continue email_counter[email] += 1 @@ -67,8 +65,8 @@ def main(): for email, lines in email_counter.most_common(): percent = (lines / total_lines * 100) if total_lines else 0 - print(f"{email}: {lines}/{total_lines} {percent:.2f}%") + print(f'{email}: {lines}/{total_lines} {percent:.2f}%') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/hatch_build.py b/hatch_build.py index 28aad1b6cd..f18f323635 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -10,14 +10,12 @@ from hatchling.builders.hooks.plugin.interface import BuildHookInterface class CustomBuildHook(BuildHookInterface): def initialize(self, version, build_data): super().initialize(version, build_data) - stderr.write(">>> Building Open Webui frontend\n") - npm = shutil.which("npm") + stderr.write('>>> Building Open Webui frontend\n') + npm = shutil.which('npm') if npm is None: - raise RuntimeError( - "NodeJS `npm` is required for building Open Webui but it was not found" - ) - stderr.write("### npm install\n") - subprocess.run([npm, "install", "--force"], check=True) # noqa: S603 - stderr.write("\n### npm run build\n") - os.environ["APP_BUILD_HASH"] = version - subprocess.run([npm, "run", "build"], check=True) # noqa: S603 + raise RuntimeError('NodeJS `npm` is required for building Open Webui but it was not found') + stderr.write('### npm install\n') + subprocess.run([npm, 'install', '--force'], check=True) # noqa: S603 + stderr.write('\n### npm run build\n') + os.environ['APP_BUILD_HASH'] = version + subprocess.run([npm, 'run', 'build'], check=True) # noqa: S603 diff --git a/package.json b/package.json index 68c746457a..887781d325 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,7 @@ "lint:types": "npm run check", "lint:backend": "pylint backend/", "format": "prettier --plugin-search-dir --write \"**/*.{js,ts,svelte,css,md,html,json}\"", - "format:backend": "black . --exclude \".venv/|/venv/\"", + "format:backend": "ruff format . --exclude .venv --exclude venv", "i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write \"src/lib/i18n/**/*.{js,json}\"", "cy:open": "cypress open", "test:frontend": "vitest --passWithNoTests",