This commit is contained in:
Timothy Jaeryang Baek
2026-04-24 16:17:46 +09:00
parent 5cc55e2278
commit 678c44c7cd
12 changed files with 268 additions and 76 deletions

View File

@@ -3,6 +3,7 @@ import json
import logging
import ssl as _stdlib_ssl
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
@@ -37,90 +38,154 @@ from typing_extensions import Self
log = logging.getLogger(__name__)
def extract_ssl_mode_from_url(url: str) -> tuple[str, str | None]:
"""Strip SSL query-string parameters from a PostgreSQL URL.
@dataclass
class SSLParams:
"""SSL parameters extracted from a PostgreSQL ``DATABASE_URL``.
asyncpg and psycopg2 use different query-string keys for SSL
(``ssl`` vs ``sslmode``). This helper removes **both** from the
URL so that each driver can receive the correct parameter through
its own mechanism (query-string re-injection for psycopg2,
``connect_args`` for asyncpg).
Returns
-------
(url_without_ssl, ssl_mode)
*url_without_ssl* is the original URL with ``ssl`` / ``sslmode``
query parameters removed. *ssl_mode* is the extracted mode
string (e.g. ``'require'``), or ``None`` if neither parameter
was present.
Non-PostgreSQL URLs are returned unchanged with ``ssl_mode=None``.
Holds the connection-mode flag and optional certificate file paths
so that each driver (asyncpg, psycopg2/libpq) can receive them in
the format it expects.
"""
if not url or not any(url.startswith(prefix) for prefix in ('postgresql://', 'postgresql+', 'postgres://')):
return url, None
mode: str | None = None
rootcert: str | None = None
cert: str | None = None
key: str | None = None
crl: str | None = None
def __bool__(self) -> bool:
return self.mode is not None
@property
def has_any(self) -> bool:
"""True when *any* SSL-related field is set (mode or cert files)."""
return any((self.mode, self.rootcert, self.cert, self.key, self.crl))
# ── URL extraction / reattachment ────────────────────────────────────
def _pop_first(params: dict[str, list[str]], key: str) -> str | None:
"""Pop a single-valued query param, returning ``None`` if absent."""
values = params.pop(key, None)
return values[0] if values else None
def extract_ssl_params_from_url(url: str) -> tuple[str, SSLParams]:
"""Strip all SSL query-string parameters from a PostgreSQL URL.
asyncpg does not accept libpq-style certificate-file keys
(``sslrootcert``, ``sslcert``, ``sslkey``, ``sslcrl``), so every
SSL-related key is removed and returned as a structured
:class:`SSLParams` object.
Returns ``(url_without_ssl, ssl_params)``. Non-PostgreSQL URLs are
returned unchanged with an empty ``SSLParams``.
"""
if not url or not any(
url.startswith(p) for p in ('postgresql://', 'postgresql+', 'postgres://')
):
return url, SSLParams()
parsed = urlparse(url)
query_params = parse_qs(parsed.query, keep_blank_values=True)
qp = parse_qs(parsed.query, keep_blank_values=True)
# Prefer sslmode (libpq canonical) over the asyncpg-only ssl key.
ssl_mode: str | None = None
for key in ('sslmode', 'ssl'):
values = query_params.pop(key, None)
if values and ssl_mode is None:
ssl_mode = values[0]
# Prefer sslmode (libpq canonical) over the asyncpg-only ``ssl`` key.
# Both must be popped unconditionally so neither leaks into the cleaned URL.
sslmode_val = _pop_first(qp, 'sslmode')
ssl_val = _pop_first(qp, 'ssl')
ssl_mode = sslmode_val or ssl_val
if ssl_mode is None:
# Nothing to strip — return the URL untouched.
return url, None
params = SSLParams(
mode=ssl_mode,
rootcert=_pop_first(qp, 'sslrootcert'),
cert=_pop_first(qp, 'sslcert'),
key=_pop_first(qp, 'sslkey'),
crl=_pop_first(qp, 'sslcrl'),
)
# Rebuild the query string without the SSL keys.
remaining_query = urlencode(query_params, doseq=True)
url_without_ssl = urlunparse(parsed._replace(query=remaining_query))
return url_without_ssl, ssl_mode
if not params.has_any:
return url, params
cleaned_query = urlencode(qp, doseq=True)
return urlunparse(parsed._replace(query=cleaned_query)), params
def build_asyncpg_ssl_args(ssl_mode: str | None) -> dict:
"""Convert a libpq-style SSL mode value to asyncpg ``connect_args``.
def reattach_ssl_params_to_url(url_without_ssl: str, ssl_params: SSLParams) -> str:
"""Re-append SSL query-string parameters to a cleaned PostgreSQL URL.
Used for psycopg2/libpq consumers that expect ``sslmode`` and the
certificate-file keys in the connection string.
"""
if not ssl_params:
return url_without_ssl
mapping = (
('sslmode', ssl_params.mode),
('sslrootcert', ssl_params.rootcert),
('sslcert', ssl_params.cert),
('sslkey', ssl_params.key),
('sslcrl', ssl_params.crl),
)
parts = [f'{k}={v}' for k, v in mapping if v]
if not parts:
return url_without_ssl
sep = '&' if '?' in url_without_ssl else '?'
return f'{url_without_ssl}{sep}{"&".join(parts)}'
# ── asyncpg SSLContext builder ───────────────────────────────────────
def _make_ssl_context(ssl_params: SSLParams, *, verify: bool) -> _stdlib_ssl.SSLContext:
"""Create an :class:`ssl.SSLContext` from *ssl_params*.
When *verify* is ``False``, hostname checking and certificate
verification are disabled (matching libpq ``require`` semantics).
"""
ctx = _stdlib_ssl.create_default_context(cafile=ssl_params.rootcert)
if not verify:
ctx.check_hostname = False
ctx.verify_mode = _stdlib_ssl.CERT_NONE
if ssl_params.cert and ssl_params.key:
ctx.load_cert_chain(certfile=ssl_params.cert, keyfile=ssl_params.key)
if verify and ssl_params.crl:
ctx.load_verify_locations(cafile=ssl_params.crl)
ctx.verify_flags |= _stdlib_ssl.VERIFY_CRL_CHECK_LEAF
return ctx
def build_asyncpg_ssl_args(ssl_params: SSLParams) -> dict:
"""Convert :class:`SSLParams` to asyncpg-compatible ``connect_args``.
Returns a dict suitable for unpacking into
``create_async_engine(..., connect_args=...)``.
``create_async_engine(...)``.
"""
if ssl_mode is None:
if not ssl_params:
return {}
mode = ssl_mode.lower()
mode = (ssl_params.mode or 'require').lower()
if mode == 'disable':
return {'connect_args': {'ssl': False}}
if mode in ('allow', 'prefer'):
# asyncpg has no direct equivalent — omit to let it try without.
return {}
if mode == 'require':
# SSL required but no certificate verification (matches libpq).
ctx = _stdlib_ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = _stdlib_ssl.CERT_NONE
return {'connect_args': {'ssl': ctx}}
return {'connect_args': {'ssl': _make_ssl_context(ssl_params, verify=False)}}
if mode in ('verify-ca', 'verify-full'):
# Full verification — use the system trust store.
ctx = _stdlib_ssl.create_default_context()
ctx = _make_ssl_context(ssl_params, verify=True)
if mode == 'verify-ca':
ctx.check_hostname = False
return {'connect_args': {'ssl': ctx}}
# Unknown value — pass through as-is and let asyncpg decide.
return {'connect_args': {'ssl': ssl_mode}}
return {'connect_args': {'ssl': ssl_params.mode}}
def reattach_ssl_mode_to_url(url_without_ssl: str, ssl_mode: str | None) -> str:
"""Re-append ``sslmode=<value>`` to a cleaned PostgreSQL URL.
Used for psycopg2 / libpq consumers that expect the canonical
``sslmode`` query-string key.
"""
if ssl_mode is None:
return url_without_ssl
separator = '&' if '?' in url_without_ssl else '?'
return f'{url_without_ssl}{separator}sslmode={ssl_mode}'
# Backwards-compatible aliases for external callers.
extract_ssl_mode_from_url = extract_ssl_params_from_url
reattach_ssl_mode_to_url = reattach_ssl_params_to_url
class JSONField(types.TypeDecorator):
@@ -150,9 +215,10 @@ class JSONField(types.TypeDecorator):
def handle_peewee_migration(DATABASE_URL):
db = None
try:
# Normalize SSL params so psycopg2 always sees `sslmode=` (never `ssl=`).
url_without_ssl, ssl_mode = extract_ssl_mode_from_url(DATABASE_URL)
normalized_url = reattach_ssl_mode_to_url(url_without_ssl, ssl_mode)
# Normalize SSL params so psycopg2 always sees `sslmode=` (never `ssl=`)
# and cert-file params are preserved in the connection string.
url_without_ssl, ssl_params = extract_ssl_params_from_url(DATABASE_URL)
normalized_url = reattach_ssl_params_to_url(url_without_ssl, ssl_params)
# Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(normalized_url.replace('postgresql://', 'postgres://'))
@@ -181,11 +247,11 @@ if ENABLE_DB_MIGRATIONS:
# Normalize SSL params from the URL once; each engine branch re-injects
# the driver-appropriate form.
DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_MODE = extract_ssl_mode_from_url(DATABASE_URL)
DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS = extract_ssl_params_from_url(DATABASE_URL)
# For psycopg2 (sync engine), re-append sslmode=<value>.
# For psycopg2 (sync engine), re-append sslmode + cert-file params.
SQLALCHEMY_DATABASE_URL = (
reattach_ssl_mode_to_url(DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_MODE) if DATABASE_SSL_MODE else DATABASE_URL
reattach_ssl_params_to_url(DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS) if DATABASE_SSL_PARAMS else DATABASE_URL
)
@@ -331,7 +397,7 @@ get_db = contextmanager(get_session)
# Use the SSL-stripped URL for asyncpg — SSL is injected via connect_args.
ASYNC_SQLALCHEMY_DATABASE_URL = _make_async_url(
DATABASE_URL_WITHOUT_SSL if DATABASE_SSL_MODE else SQLALCHEMY_DATABASE_URL
DATABASE_URL_WITHOUT_SSL if DATABASE_SSL_PARAMS else SQLALCHEMY_DATABASE_URL
)
if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
@@ -352,7 +418,7 @@ if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
else:
# Inject asyncpg-compatible SSL connect_args when the user specified
# sslmode/ssl in DATABASE_URL.
asyncpg_ssl_args = build_asyncpg_ssl_args(DATABASE_SSL_MODE)
asyncpg_ssl_args = build_asyncpg_ssl_args(DATABASE_SSL_PARAMS)
if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0:

View File

@@ -5,7 +5,7 @@ from alembic import context
from open_webui.models.auths import Auth
from open_webui.models.calendar import Calendar, CalendarEvent, CalendarEventAttendee # noqa: F401
from open_webui.env import DATABASE_URL, DATABASE_PASSWORD, LOG_FORMAT
from open_webui.internal.db import extract_ssl_mode_from_url, reattach_ssl_mode_to_url
from open_webui.internal.db import extract_ssl_params_from_url, reattach_ssl_params_to_url
from sqlalchemy import engine_from_config, pool, create_engine
# this is the Alembic Config object, which provides
@@ -38,8 +38,8 @@ target_metadata = Auth.metadata
DB_URL = DATABASE_URL
# Normalize SSL query params for psycopg2 (Alembic uses psycopg2, not asyncpg).
url_without_ssl, ssl_mode = extract_ssl_mode_from_url(DB_URL)
DB_URL = reattach_ssl_mode_to_url(url_without_ssl, ssl_mode) if ssl_mode else DB_URL
url_without_ssl, ssl_params = extract_ssl_params_from_url(DB_URL)
DB_URL = reattach_ssl_params_to_url(url_without_ssl, ssl_params) if ssl_params else DB_URL
if DB_URL:
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))

View File

@@ -320,6 +320,21 @@ class OAuthSessionTable:
log.error(f'Error deleting OAuth sessions by user ID: {e}')
return False
async def delete_sessions_by_user_id_and_provider(
self, user_id: str, provider: str, db: Optional[AsyncSession] = None
) -> bool:
"""Delete all OAuth sessions for a specific user and provider"""
try:
async with get_async_db_context(db) as db:
result = await db.execute(
delete(OAuthSession).filter_by(user_id=user_id, provider=provider)
)
await db.commit()
return result.rowcount > 0
except Exception as e:
log.error(f'Error deleting OAuth sessions for user {user_id} and provider {provider}: {e}')
return False
async def delete_sessions_by_provider(self, provider: str, db: Optional[AsyncSession] = None) -> bool:
"""Delete all OAuth sessions for a provider"""
try:

View File

@@ -172,10 +172,17 @@ async def get_session_user(
user=Depends(get_current_user),
db: AsyncSession = Depends(get_async_session),
):
token = None
auth_header = request.headers.get('Authorization')
auth_token = get_http_authorization_cred(auth_header)
token = auth_token.credentials
data = decode_token(token)
if auth_header:
auth_token = get_http_authorization_cred(auth_header)
if auth_token is not None:
token = auth_token.credentials
if token is None:
token = request.cookies.get('token')
if token is None and getattr(request.state, 'token', None):
token = request.state.token.credentials
data = decode_token(token) if token else None
expires_at = None
@@ -773,8 +780,9 @@ async def signout(request: Request, response: Response, db: AsyncSession = Depen
auth_header = request.headers.get('Authorization')
if auth_header:
auth_cred = get_http_authorization_cred(auth_header)
token = auth_cred.credentials
else:
if auth_cred is not None:
token = auth_cred.credentials
if token is None:
token = request.cookies.get('token')
if token:
@@ -853,6 +861,33 @@ async def signout(request: Request, response: Response, db: AsyncSession = Depen
return JSONResponse(status_code=200, content={'status': True}, headers=response.headers)
############################
# OAuth Session Management
############################
@router.delete('/oauth/sessions/{provider:path}', response_model=bool)
async def delete_oauth_session_by_provider(
provider: str,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
):
"""
Disconnect the current user's OAuth session for a specific provider.
The provider string matches the 'provider' field in the oauth_session table
(e.g. 'mcp:server-id' for MCP connections).
"""
result = await OAuthSessions.delete_sessions_by_user_id_and_provider(
user.id, provider, db=db
)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='No OAuth session found for this provider',
)
return True
############################
# AddUser
############################

View File

@@ -917,3 +917,4 @@ async def update_tools_user_valves_by_id(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)

View File

@@ -550,6 +550,8 @@ async def update_user_by_id(
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except HTTPException:
raise
except Exception as e:
log.error(f'Error checking primary admin status: {e}')
raise HTTPException(
@@ -631,6 +633,8 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user), db: Asyn
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except HTTPException:
raise
except Exception as e:
log.error(f'Error checking primary admin status: {e}')
raise HTTPException(

View File

@@ -712,3 +712,33 @@ export const deleteAPIKey = async (token: string) => {
}
return res;
};
export const deleteOAuthSession = async (token: string, provider: string) => {
let error = null;
const res = await fetch(
`${WEBUI_API_BASE_URL}/auths/oauth/sessions/${encodeURIComponent(provider)}`,
{
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
}
)
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.error(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};

View File

@@ -647,3 +647,4 @@ export const setBanners = async (token: string, banners: Banner[]) => {
return res;
};

View File

@@ -483,3 +483,4 @@ export const updateUserValvesById = async (token: string, id: string, valves: ob
return res;
};

View File

@@ -550,3 +550,4 @@ export const getUserGroupsById = async (token: string, userId: string) => {
return res;
};

View File

@@ -13,8 +13,11 @@
} from '$lib/stores';
import { getOAuthClientAuthorizationUrl } from '$lib/apis/configs';
import { deleteOAuthSession } from '$lib/apis/auths';
import { getTools } from '$lib/apis/tools';
import { toast } from 'svelte-sonner';
import Knobs from '$lib/components/icons/Knobs.svelte';
import Dropdown from '$lib/components/common/Dropdown.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
@@ -27,6 +30,7 @@
import Terminal from '$lib/components/icons/Terminal.svelte';
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
import LinkSlash from '$lib/components/icons/LinkSlash.svelte';
const i18n = getContext('i18n');
@@ -375,6 +379,40 @@
</div>
</div>
{#if (tools[toolId]?.authenticated ?? true) && toolId.startsWith('server:mcp:')}
<div class="shrink-0">
<Tooltip content={$i18n.t('Disconnect OAuth')}>
<button
class="self-center w-fit text-sm text-gray-600 dark:text-gray-400 hover:text-gray-700 dark:hover:text-gray-300 transition rounded-full"
type="button"
on:click={async (e) => {
e.stopPropagation();
e.preventDefault();
const parts = toolId.split(':');
const serverId = parts.at(-1) ?? toolId;
const provider = `mcp:${serverId}`;
try {
await deleteOAuthSession(localStorage.token, provider);
toast.success($i18n.t('OAuth session disconnected'));
// Refresh tools to update authenticated state
_tools.set(await getTools(localStorage.token));
// Remove from selected if it was selected
selectedToolIds = selectedToolIds.filter((id) => id !== toolId);
} catch (err) {
toast.error(err ?? $i18n.t('Failed to disconnect'));
}
}}
>
<LinkSlash className="size-3.5" />
</button>
</Tooltip>
</div>
{/if}
{#if tools[toolId]?.has_user_valves && ($user?.role === 'admin' || ($user?.permissions?.chat?.valves ?? true))}
<div class=" shrink-0">
<Tooltip content={$i18n.t('Valves')}>

View File

@@ -4987,10 +4987,10 @@
},
"pathspec": {
"name": "pathspec",
"version": "1.0.4",
"file_name": "pathspec-1.0.4-py3-none-any.whl",
"version": "1.1.0",
"file_name": "pathspec-1.1.0-py3-none-any.whl",
"install_dir": "site",
"sha256": "fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723",
"sha256": "574b128f7456bd899045ccd142dd446af7e6cfd0072d63ad73fbc55fbb4aaa42",
"package_type": "package",
"imports": [
"pathspec"