diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py index 805de67c65..f918d26967 100644 --- a/authentik/sources/oauth/api/source.py +++ b/authentik/sources/oauth/api/source.py @@ -59,7 +59,11 @@ class OAuthSourceSerializer(SourceSerializer): def validate(self, attrs: dict) -> dict: session = get_http_session() - source_type = registry.find_type(attrs["provider_type"]) + provider_type_name = attrs.get( + "provider_type", + self.instance.provider_type if self.instance else None, + ) + source_type = registry.find_type(provider_type_name) well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url inferred_oidc_jwks_url = None @@ -101,16 +105,15 @@ class OAuthSourceSerializer(SourceSerializer): config = jwks_config.json() attrs["oidc_jwks"] = config - provider_type = registry.find_type(attrs.get("provider_type", "")) for url in [ "authorization_url", "access_token_url", "profile_url", ]: - if getattr(provider_type, url, None) is None: + if getattr(source_type, url, None) is None: if url not in attrs: raise ValidationError( - f"{url} is required for provider {provider_type.verbose_name}" + f"{url} is required for provider {source_type.verbose_name}" ) return attrs diff --git a/authentik/sources/oauth/tests/test_api.py b/authentik/sources/oauth/tests/test_api.py new file mode 100644 index 0000000000..0452bbe8bb --- /dev/null +++ b/authentik/sources/oauth/tests/test_api.py @@ -0,0 +1,31 @@ +from django.urls import reverse +from rest_framework.test import APITestCase + +from authentik.core.tests.utils import create_test_admin_user +from authentik.lib.generators import generate_id +from authentik.sources.oauth.models import OAuthSource + + +class TestOAuthSourceAPI(APITestCase): + def setUp(self): + self.source = OAuthSource.objects.create( + name=generate_id(), + slug=generate_id(), + provider_type="openidconnect", + authorization_url="", + profile_url="", + consumer_key=generate_id(), + ) + self.user = create_test_admin_user() + + def test_patch_no_type(self): + self.client.force_login(self.user) + res = self.client.patch( + reverse("authentik_api:oauthsource-detail", kwargs={"slug": self.source.slug}), + { + "authorization_url": f"https://{generate_id()}", + "profile_url": f"https://{generate_id()}", + "access_token_url": f"https://{generate_id()}", + }, + ) + self.assertEqual(res.status_code, 200)