sources/oauth: Allow patching without provider type (#21211)

* sources/oauth: Allow patching without provider type

* fix, add test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Marcus Yanello
2026-03-28 08:31:29 -05:00
committed by GitHub
parent b2061ab3b2
commit 9a974f14c8
2 changed files with 38 additions and 4 deletions

View File

@@ -59,7 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):
def validate(self, attrs: dict) -> dict:
session = get_http_session()
source_type = registry.find_type(attrs["provider_type"])
provider_type_name = attrs.get(
"provider_type",
self.instance.provider_type if self.instance else None,
)
source_type = registry.find_type(provider_type_name)
well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url
inferred_oidc_jwks_url = None
@@ -101,16 +105,15 @@ class OAuthSourceSerializer(SourceSerializer):
config = jwks_config.json()
attrs["oidc_jwks"] = config
provider_type = registry.find_type(attrs.get("provider_type", ""))
for url in [
"authorization_url",
"access_token_url",
"profile_url",
]:
if getattr(provider_type, url, None) is None:
if getattr(source_type, url, None) is None:
if url not in attrs:
raise ValidationError(
f"{url} is required for provider {provider_type.verbose_name}"
f"{url} is required for provider {source_type.verbose_name}"
)
return attrs

View File

@@ -0,0 +1,31 @@
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource
class TestOAuthSourceAPI(APITestCase):
def setUp(self):
self.source = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key=generate_id(),
)
self.user = create_test_admin_user()
def test_patch_no_type(self):
self.client.force_login(self.user)
res = self.client.patch(
reverse("authentik_api:oauthsource-detail", kwargs={"slug": self.source.slug}),
{
"authorization_url": f"https://{generate_id()}",
"profile_url": f"https://{generate_id()}",
"access_token_url": f"https://{generate_id()}",
},
)
self.assertEqual(res.status_code, 200)