diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 47302e7535..4a7d79d87c 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -304,57 +304,67 @@ async def get_authorization_server_discovery_urls(server_url: str) -> list[str]: ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: if response.status == 401: + resource_metadata_urls = [] match = re.search( r'resource_metadata=(?:"([^"]+)"|([^\s,]+))', response.headers.get('WWW-Authenticate', ''), ) if match: - resource_metadata_url = match.group(1) or match.group(2) - log.debug(f'Found resource_metadata URL: {resource_metadata_url}') + resource_metadata_urls = [match.group(1) or match.group(2)] + log.debug(f'Found resource_metadata URL: {resource_metadata_urls[0]}') + else: + # Fall back to well-known resource metadata URIs (RFC 9728 ยง4.2) + parsed, base_url = get_parsed_and_base_url(server_url) + if parsed.path and parsed.path != '/': + path = parsed.path.rstrip('/') + resource_metadata_urls.append( + urllib.parse.urljoin(base_url, f'/.well-known/oauth-protected-resource{path}') + ) + resource_metadata_urls.append( + urllib.parse.urljoin(base_url, '/.well-known/oauth-protected-resource') + ) + log.debug(f'No resource_metadata in header, trying well-known URIs: {resource_metadata_urls}') - # Step 2: Fetch Protected Resource metadata - async with session.get( - resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL - ) as resource_response: - if resource_response.status == 200: - resource_metadata = await resource_response.json() + # Fetch Protected Resource metadata from candidate URLs + for resource_metadata_url in resource_metadata_urls: + try: + async with session.get( + resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as resource_response: + if resource_response.status == 200: + resource_metadata = await resource_response.json() - # Step 3: Extract authorization_servers - servers = resource_metadata.get('authorization_servers', []) - if servers: - authorization_servers = servers - log.debug(f'Discovered authorization servers: {servers}') + servers = resource_metadata.get('authorization_servers', []) + if servers: + authorization_servers = servers + log.debug(f'Discovered authorization servers: {servers}') + break + except Exception as e: + log.debug(f'Failed to fetch resource metadata from {resource_metadata_url}: {e}') + continue except Exception as e: log.debug(f'MCP Protected Resource discovery failed: {e}') discovery_urls = [] for auth_server in authorization_servers: auth_server = auth_server.rstrip('/') - discovery_urls.extend( - [ - f'{auth_server}/.well-known/oauth-authorization-server', - f'{auth_server}/.well-known/openid-configuration', - ] - ) + discovery_urls.extend(_build_well_known_urls(auth_server)) return discovery_urls -async def get_discovery_urls(server_url) -> list[str]: - urls = await get_authorization_server_discovery_urls(server_url) +def _build_well_known_urls(server_url: str) -> list[str]: + """Build RFC 8414 / OIDC Discovery well-known URLs for a server URL.""" parsed, base_url = get_parsed_and_base_url(server_url) + urls = [] if parsed.path and parsed.path != '/': - # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery - tenant = parsed.path.rstrip('/') + path = parsed.path.rstrip('/') urls.extend( [ - urllib.parse.urljoin( - base_url, - f'/.well-known/oauth-authorization-server{tenant}', - ), - urllib.parse.urljoin(base_url, f'/.well-known/openid-configuration{tenant}'), - urllib.parse.urljoin(base_url, f'{tenant}/.well-known/openid-configuration'), + urllib.parse.urljoin(base_url, f'/.well-known/oauth-authorization-server{path}'), + urllib.parse.urljoin(base_url, f'/.well-known/openid-configuration{path}'), + urllib.parse.urljoin(base_url, f'{path}/.well-known/openid-configuration'), ] ) @@ -368,6 +378,12 @@ async def get_discovery_urls(server_url) -> list[str]: return urls +async def get_discovery_urls(server_url) -> list[str]: + urls = await get_authorization_server_discovery_urls(server_url) + urls.extend(_build_well_known_urls(server_url)) + return urls + + # TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration. # This is not currently supported. async def get_oauth_client_info_with_dynamic_client_registration(