providers/oauth2: clip device authorization scope against the provider's ScopeMapping set (cherry-pick #21701 to version-2025.12) (#21798)

Cherry-pick #21701 to version-2025.12 (with conflicts)

This cherry-pick has conflicts that need manual resolution.

Original PR: #21701
Original commit: cce646b132

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

# Conflicts:
#	authentik/providers/oauth2/tests/test_device_backchannel.py
#	authentik/providers/oauth2/views/device_backchannel.py

Co-authored-by: Sai Asish Y <say.apm35@gmail.com>
This commit is contained in:
authentik-automation[bot]
2026-04-23 18:26:55 +02:00
committed by GitHub
parent ad02dc6b92
commit 71d2a4a5dd
2 changed files with 73 additions and 4 deletions

View File

@@ -5,10 +5,11 @@ from json import loads
from django.urls import reverse
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_flow
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.tests.utils import OAuthTestCase
@@ -96,3 +97,57 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
self.assertEqual(res.status_code, 200)
body = loads(res.content.decode())
self.assertEqual(body["expires_in"], 60)
@apply_blueprint("system/providers-oauth2.yaml")
def test_backchannel_scopes(self):
"""Test backchannel"""
self.provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
res = self.client.post(
reverse("authentik_providers_oauth2:device"),
HTTP_AUTHORIZATION=f"Basic {creds}",
data={"scope": "openid email"},
)
self.assertEqual(res.status_code, 200)
body = loads(res.content.decode())
self.assertEqual(body["expires_in"], 60)
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
self.assertIsNotNone(token)
self.assertEqual(len(token.scope), 2)
self.assertIn("openid", token.scope)
self.assertIn("email", token.scope)
@apply_blueprint("system/providers-oauth2.yaml")
def test_backchannel_scopes_extra(self):
"""Test backchannel"""
self.provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
res = self.client.post(
reverse("authentik_providers_oauth2:device"),
HTTP_AUTHORIZATION=f"Basic {creds}",
data={"scope": "openid email foo"},
)
self.assertEqual(res.status_code, 200)
body = loads(res.content.decode())
self.assertEqual(body["expires_in"], 60)
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
self.assertIsNotNone(token)
self.assertEqual(len(token.scope), 2)
self.assertIn("openid", token.scope)
self.assertIn("email", token.scope)

View File

@@ -15,7 +15,7 @@ from authentik.core.models import Application
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.errors import DeviceCodeError
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
@@ -28,7 +28,7 @@ class DeviceView(View):
client_id: str
provider: OAuth2Provider
scopes: list[str] = []
scopes: set[str] = []
def parse_request(self):
"""Parse incoming request"""
@@ -44,7 +44,21 @@ class DeviceView(View):
raise DeviceCodeError("invalid_client") from None
self.provider = provider
self.client_id = client_id
self.scopes = self.request.POST.get("scope", "").split(" ")
scopes_to_check = set(self.request.POST.get("scope", "").split())
default_scope_names = set(
ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
"scope_name", flat=True
)
)
self.scopes = scopes_to_check
if not scopes_to_check.issubset(default_scope_names):
LOGGER.info(
"Application requested scopes not configured, setting to overlap",
scope_allowed=default_scope_names,
scope_given=self.scopes,
)
self.scopes = self.scopes.intersection(default_scope_names)
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
throttle = AnonRateThrottle()