Compare commits

..

15 Commits

Author SHA1 Message Date
Tana M Berry
b5e38dcacc tweak 2025-07-07 11:57:27 -05:00
Tana M Berry
aca8c883db added punctuation 2025-07-07 11:43:16 -05:00
Tana M Berry
47a54fedd0 tweak 2025-07-04 18:48:19 -05:00
Tana M Berry
bc00e7284b tweak to bump 2025-07-04 16:35:47 -05:00
Tana M Berry
207d3557e6 fixed image 2025-07-03 21:48:34 -05:00
Tana M Berry
1beea91bbf fixed link 2025-07-03 18:42:12 -05:00
Tana M Berry
02bee093b7 all edits 2025-07-03 18:36:47 -05:00
Tana M Berry
17c957b94d tweak to bump build 2025-07-02 18:08:42 -05:00
Tana M Berry
fb2450169c dom and dewi edits 2025-07-02 18:03:17 -05:00
Tana M Berry
d8eb2bd016 tweaks 2025-07-02 18:03:17 -05:00
authentik-automation[bot]
e8f56df048 Optimised images with calibre/image-actions 2025-07-02 18:03:17 -05:00
Tana M Berry
5a6c13e991 fix image link 2025-07-02 18:03:17 -05:00
Tana M Berry
b28af354a2 major surgery 2025-07-02 18:03:17 -05:00
Tana M Berry
be9572b12b more content 2025-07-02 18:03:17 -05:00
Tana M Berry
b6d1c055cb tweak 2025-07-02 18:03:17 -05:00
1158 changed files with 19728 additions and 31096 deletions

View File

@@ -31,4 +31,4 @@ If changes to the frontend have been made
If applicable
- [ ] The documentation has been updated
- [ ] The documentation has been formatted (`make docs`)
- [ ] The documentation has been formatted (`make website`)

View File

@@ -27,8 +27,8 @@ jobs:
- name: Publish package
working-directory: gen-ts-api/
run: |
npm i
npm publish --tag generated
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
- name: Upgrade /web

View File

@@ -62,7 +62,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
go build -o /go/authentik ./cmd/server
# Stage 3: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1 AS geoip
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.0 AS geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
ENV GEOIPUPDATE_VERBOSE="1"
@@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 4: Download uv
FROM ghcr.io/astral-sh/uv:0.8.0 AS uv
FROM ghcr.io/astral-sh/uv:0.7.18 AS uv
# Stage 5: Base python image
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base

View File

@@ -1,4 +1,4 @@
.PHONY: gen dev-reset all clean test web docs
.PHONY: gen dev-reset all clean test web website
SHELL := /usr/bin/env bash
.SHELLFLAGS += ${SHELLFLAGS} -e -o pipefail
@@ -73,7 +73,7 @@ core-i18n-extract:
--ignore website \
-l en
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
install: web-install website-install core-install ## Install all requires dependencies for `web`, `website` and `core`
dev-drop-db:
dropdb -U ${pg_user} -h ${pg_host} ${pg_name}
@@ -183,23 +183,18 @@ gen-dev-config: ## Generate a local development config file
gen: gen-build gen-client-ts
#########################
## Node.js
#########################
node-install: ## Install the necessary libraries to build Node.js packages
npm ci
npm ci --prefix web
#########################
## Web
#########################
web-build: node-install ## Build the Authentik UI
web-build: web-install ## Build the Authentik UI
cd web && npm run build
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci
web-test: ## Run tests for the Authentik UI
cd web && npm run test
@@ -226,28 +221,22 @@ web-i18n-extract:
cd web && npm run extract-locales
#########################
## Docs
## Website
#########################
docs: docs-lint-fix docs-build ## Automatically fix formatting issues in the Authentik docs source code, lint the code, and compile it
website: website-lint-fix website-build ## Automatically fix formatting issues in the Authentik website/docs source code, lint the code, and compile it
docs-install:
npm ci --prefix website
website-install:
cd website && npm ci
docs-lint-fix: lint-codespell
npm run prettier --prefix website
website-lint-fix: lint-codespell
cd website && npm run prettier
docs-build:
npm run build --prefix website
website-build:
cd website && npm run build
docs-watch: ## Build and watch the topics documentation
npm run start --prefix website
docs-integrations-build:
npm run build --prefix website -w integrations
docs-integrations-watch: ## Build and watch the Integrations documentation
npm run start --prefix website -w integrations
website-watch: ## Build and watch the documentation website, updating automatically
cd website && npm run watch
#########################
## Docker

View File

@@ -52,27 +52,6 @@ class TestBrands(APITestCase):
},
)
def test_brand_subdomain_same_suffix(self):
"""Test Current brand API"""
Brand.objects.all().delete()
Brand.objects.create(domain="bar.baz", branding_title="custom")
Brand.objects.create(domain="foo.bar.baz", branding_title="custom")
self.assertJSONEqual(
self.client.get(
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
).content.decode(),
{
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "custom",
"branding_custom_css": "",
"matched_domain": "foo.bar.baz",
"ui_footer_links": [],
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
)
def test_fallback(self):
"""Test fallback brand"""
Brand.objects.all().delete()

View File

@@ -4,7 +4,6 @@ from typing import Any
from django.db.models import F, Q
from django.db.models import Value as V
from django.db.models.functions import Length
from django.http.request import HttpRequest
from django.utils.html import _json_script_escapes
from django.utils.safestring import mark_safe
@@ -21,9 +20,9 @@ DEFAULT_BRAND = Brand(domain="fallback")
def get_brand_for_request(request: HttpRequest) -> Brand:
"""Get brand object for current request"""
db_brands = (
Brand.objects.annotate(host_domain=V(request.get_host()), match_length=Length("domain"))
Brand.objects.annotate(host_domain=V(request.get_host()))
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
.order_by("-match_length", "default")
.order_by("default")
)
brands = list(db_brands.all())
if len(brands) < 1:

View File

@@ -149,10 +149,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
return applications
def _filter_applications_with_launch_url(
self, paginated_apps: Iterator[Application]
self, pagined_apps: Iterator[Application]
) -> list[Application]:
applications = []
for app in paginated_apps:
for app in pagined_apps:
if app.get_launch_url():
applications.append(app)
return applications

View File

@@ -11,6 +11,7 @@ from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.types import PolicyRequest
PROPERTY_MAPPING_TIME = Histogram(
@@ -68,11 +69,12 @@ class PropertyMappingEvaluator(BaseEvaluator):
# For dry-run requests we don't save exceptions
if self.dry_run:
return
error_string = exception_to_string(exc)
event = Event.new(
EventAction.PROPERTY_MAPPING_EXCEPTION,
expression=expression_source,
message="Failed to execute property mapping",
).with_exception(exc)
message=error_string,
)
if "request" in self._context:
req: PolicyRequest = self._context["request"]
if req.http_request:

View File

@@ -1,24 +0,0 @@
# Generated by Django 5.1.11 on 2025-07-03 13:08
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0048_delete_oldauthenticatedsession_content_type"),
]
operations = [
migrations.AlterModelOptions(
name="token",
options={
"permissions": [
("view_token_key", "View token's key"),
("set_token_key", "Set a token's key"),
],
"verbose_name": "Token",
"verbose_name_plural": "Tokens",
},
),
]

View File

@@ -953,10 +953,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
models.Index(fields=["identifier"]),
models.Index(fields=["key"]),
]
permissions = [
("view_token_key", _("View token's key")),
("set_token_key", _("Set a token's key")),
]
permissions = [("view_token_key", _("View token's key"))]
def __str__(self):
description = f"{self.identifier}"

View File

@@ -12,7 +12,6 @@ from rest_framework.fields import CharField, IntegerField
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.validators import UniqueValidator
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.used_by import UsedByMixin
@@ -54,7 +53,6 @@ class LicenseSerializer(ModelSerializer):
"external_users",
]
extra_kwargs = {
"key": {"validators": [UniqueValidator(queryset=License.objects.all())]},
"name": {"read_only": True},
"expiry": {"read_only": True},
"internal_users": {"read_only": True},

View File

@@ -65,17 +65,13 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
data[field.name] = deepcopy(field_value)
return cleanse_dict(data)
def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
def diff(self, before: dict, after: dict) -> dict:
"""Generate diff between dicts"""
diff = {}
for key, value in before.items():
if update_fields and key not in update_fields:
continue
if after.get(key) != value:
diff[key] = {"previous_value": value, "new_value": after.get(key)}
for key, value in after.items():
if update_fields and key not in update_fields:
continue
if key not in before and key not in diff and before.get(key) != value:
diff[key] = {"previous_value": before.get(key), "new_value": value}
return sanitize_item(diff)
@@ -99,7 +95,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
instance: Model,
created: bool,
thread_kwargs: dict | None = None,
update_fields: list[str] | None = None,
**_,
):
if not self.enabled:
@@ -113,7 +108,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
prev_state = {}
# Get current state
new_state = self.serialize_simple(instance)
diff = self.diff(prev_state, new_state, update_fields)
diff = self.diff(prev_state, new_state)
thread_kwargs["diff"] = diff
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

View File

@@ -7,7 +7,6 @@ from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.audit.middleware import EnterpriseAuditMiddleware
from authentik.events.models import Event, EventAction
from authentik.events.utils import sanitize_item
from authentik.lib.generators import generate_id
@@ -209,23 +208,3 @@ class TestEnterpriseAudit(APITestCase):
diff,
{"users": {"remove": [user.pk]}},
)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_diff_update_fields(self):
"""Test update audit log"""
self.client.force_login(self.user)
diff = EnterpriseAuditMiddleware(None).diff(
{
"foo": "bar",
"is_active": False,
},
{
"foo": "baz",
"is_active": True,
},
update_fields=["is_active"],
)
self.assertEqual(diff, {"is_active": {"new_value": True, "previous_value": False}})

View File

@@ -6,7 +6,7 @@ from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import SearchFilter
from rest_framework.filters import BaseFilterBackend, SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
@@ -39,7 +39,7 @@ class BaseSchema(DjangoQLSchema):
return super().resolve_name(name)
class QLSearch(SearchFilter):
class QLSearch(BaseFilterBackend):
"""rest_framework search filter which uses DjangoQL"""
def __init__(self):

View File

@@ -16,7 +16,7 @@ from authentik.stages.authenticator.models import Device
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
"""Setup Google Chrome Device Trust connection"""
"""Setup Google Chrome Device-trust connection"""
credentials = models.JSONField()

View File

@@ -17,7 +17,6 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
from authentik.stages.user_login.stage import PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE
# Header we get from chrome that initiates verified access
HEADER_DEVICE_TRUST = "X-Device-Trust"
@@ -28,8 +27,6 @@ HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response"
# Header value for x-device-trust that initiates the flow
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS = "endpoints"
@method_decorator(xframe_options_sameorigin, name="dispatch")
class GoogleChromeDeviceTrustConnector(View):
@@ -84,14 +81,7 @@ class GoogleChromeDeviceTrustConnector(View):
)
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS, []
)
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS][PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS].append(
response
)
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, True
)
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", [])
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response)
request.session[SESSION_KEY_PLAN] = flow_plan
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")

View File

@@ -20,7 +20,7 @@ from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.utils.errors import exception_to_dict
from authentik.lib.utils.errors import exception_to_string
from authentik.stages.authenticator_static.models import StaticToken
IGNORED_MODELS = tuple(
@@ -170,16 +170,14 @@ class AuditMiddleware:
thread = EventNewThread(
EventAction.SUSPICIOUS_REQUEST,
request,
message=str(exception),
exception=exception_to_dict(exception),
message=exception_to_string(exception),
)
thread.run()
elif not should_ignore_exception(exception):
thread = EventNewThread(
EventAction.SYSTEM_EXCEPTION,
request,
message=str(exception),
exception=exception_to_dict(exception),
message=exception_to_string(exception),
)
thread.run()

View File

@@ -38,7 +38,6 @@ from authentik.events.utils import (
)
from authentik.lib.models import DomainlessURLValidator, SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_dict
from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.models import PolicyBindingModel
@@ -164,12 +163,6 @@ class Event(SerializerModel, ExpiringModel):
event = Event(action=action, app=app, context=cleaned_kwargs)
return event
def with_exception(self, exc: Exception) -> "Event":
"""Add data from 'exc' to the event in a database-saveable format"""
self.context.setdefault("message", str(exc))
self.context["exception"] = exception_to_dict(exc)
return self
def set_user(self, user: User) -> "Event":
"""Set `.user` based on user, ensuring the correct attributes are copied.
This should only be used when self.from_http is *not* used."""

View File

@@ -127,8 +127,8 @@ class SystemTask(TenantTask):
)
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=f"Task {self.__name__} encountered an error",
).with_exception(exc).save()
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
).save()
def run(self, *args, **kwargs):
raise NotImplementedError

View File

@@ -62,7 +62,6 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
policy_engine.mode = PolicyEngineMode.MODE_ANY
policy_engine.empty_result = False
policy_engine.use_cache = False
policy_engine.request.obj = event
policy_engine.request.context["event"] = event
policy_engine.build()
result = policy_engine.result

View File

@@ -56,6 +56,7 @@ from authentik.flows.planner import (
)
from authentik.flows.stage import AccessDeniedStage, StageView
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
from authentik.policies.engine import PolicyEngine
@@ -238,8 +239,8 @@ class FlowExecutorView(APIView):
capture_exception(exc)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message="System exception during flow execution.",
).with_exception(exc).from_http(self.request)
message=exception_to_string(exc),
).from_http(self.request)
challenge = FlowErrorChallenge(self.request, exc)
challenge.is_valid(raise_exception=True)
return to_stage_response(self.request, HttpChallengeResponse(challenge))

View File

@@ -14,6 +14,7 @@ from authentik.events.models import Event, EventAction
from authentik.lib.expression.exceptions import ControlFlowException
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
from authentik.lib.utils.errors import exception_to_string
if TYPE_CHECKING:
from django.db.models import Model
@@ -105,9 +106,9 @@ class BaseOutgoingSyncClient[
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Failed to evaluate property-mapping",
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=exc.mapping,
).with_exception(exc).save()
).save()
raise StopSync(exc, obj, exc.mapping) from exc
if not raw_final_object:
raise StopSync(ValueError("No mappings configured"), obj)

View File

@@ -2,8 +2,6 @@
from traceback import extract_tb
from structlog.tracebacks import ExceptionDictTransformer
from authentik.lib.utils.reflection import class_to_path
TRACEBACK_HEADER = "Traceback (most recent call last):"
@@ -19,8 +17,3 @@ def exception_to_string(exc: Exception) -> str:
f"{class_to_path(exc.__class__)}: {str(exc)}",
]
)
def exception_to_dict(exc: Exception) -> dict:
"""Format exception as a dictionary"""
return ExceptionDictTransformer()((type(exc), exc, exc.__traceback__))

View File

@@ -35,6 +35,7 @@ from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.lib.models import InheritanceForeignKey, SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string
from authentik.outposts.controllers.k8s.utils import get_namespace
OUR_VERSION = parse(__version__)
@@ -325,8 +326,9 @@ class Outpost(SerializerModel, ManagedModel):
"While setting the permissions for the service-account, a "
"permission was not found: Check "
"https://goauthentik.io/docs/troubleshooting/missing_permission"
),
).with_exception(exc).set_user(user).save()
)
+ exception_to_string(exc),
).set_user(user).save()
else:
app_label, perm = model_or_perm.split(".")
permission = Permission.objects.filter(

View File

@@ -1,11 +1,11 @@
"""authentik policy engine"""
from collections.abc import Iterable
from collections.abc import Iterator
from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection
from time import perf_counter
from django.core.cache import cache
from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
@@ -67,11 +67,14 @@ class PolicyEngine:
self.__processes: list[PolicyProcessInfo] = []
self.use_cache = True
self.__expected_result_count = 0
self.__static_result: PolicyResult | None = None
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
def iterate_bindings(self) -> Iterator[PolicyBinding]:
"""Make sure all Policies are their respective classes"""
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
return (
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
.order_by("order")
.iterator()
)
def _check_policy_type(self, binding: PolicyBinding):
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
@@ -81,66 +84,30 @@ class PolicyEngine:
def _check_cache(self, binding: PolicyBinding):
if not self.use_cache:
return False
# It's a bit silly to time this, but
with HIST_POLICIES_EXECUTION_TIME.labels(
binding_order=binding.order,
binding_target_type=binding.target_type,
binding_target_name=binding.target_name,
object_pk=str(self.request.obj.pk),
object_type=class_to_path(self.request.obj.__class__),
mode="cache_retrieve",
).time():
key = cache_key(binding, self.request)
cached_policy = cache.get(key, None)
if not cached_policy:
return False
before = perf_counter()
key = cache_key(binding, self.request)
cached_policy = cache.get(key, None)
duration = max(perf_counter() - before, 0)
if not cached_policy:
return False
self.logger.debug(
"P_ENG: Taking result from cache",
binding=binding,
cache_key=key,
request=self.request,
)
HIST_POLICIES_EXECUTION_TIME.labels(
binding_order=binding.order,
binding_target_type=binding.target_type,
binding_target_name=binding.target_name,
object_pk=str(self.request.obj.pk),
object_type=class_to_path(self.request.obj.__class__),
mode="cache_retrieve",
).observe(duration)
# It's a bit silly to time this, but
self.__cached_policies.append(cached_policy)
return True
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
"""Check static bindings if possible"""
aggrs = {
"total": Count(
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
),
}
if self.request.user.pk:
all_groups = self.request.user.all_groups()
aggrs["passing"] = Count(
"pk",
filter=Q(
Q(
Q(user=self.request.user) | Q(group__in=all_groups),
negate=False,
)
| Q(
Q(~Q(user=self.request.user), user__isnull=False)
| Q(~Q(group__in=all_groups), group__isnull=False),
negate=True,
),
enabled=True,
),
)
matched_bindings = bindings.aggregate(**aggrs)
passing = False
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
# If we didn't find any static bindings, do nothing
return
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
if matched_bindings.get("passing", 0) > 0:
# Any passing static binding -> passing
passing = True
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
# No matching static bindings but at least one is configured -> not passing
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine":
"""Build wrapper which monitors performance"""
with (
@@ -156,12 +123,7 @@ class PolicyEngine:
span: Span
span.set_data("pbm", self.__pbm)
span.set_data("request", self.request)
bindings = self.bindings()
policy_bindings = bindings
if isinstance(bindings, QuerySet):
self.compute_static_bindings(bindings)
policy_bindings = [x for x in bindings if x.policy]
for binding in policy_bindings:
for binding in self.iterate_bindings():
self.__expected_result_count += 1
self._check_policy_type(binding)
@@ -191,13 +153,10 @@ class PolicyEngine:
@property
def result(self) -> PolicyResult:
"""Get policy-checking result"""
self.__processes.sort(key=lambda x: x.binding.order)
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
all_results = list(process_results + self.__cached_policies)
if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices")
if self.__static_result:
all_results.append(self.__static_result)
# No results, no policies attached -> passing
if len(all_results) == 0:
return PolicyResult(self.empty_result)

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Optional
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_SSO
from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.policies.exceptions import PolicyException
@@ -46,10 +45,6 @@ class PolicyEvaluator(BaseEvaluator):
self.set_http_request(request.http_request)
self._context["request"] = request
self._context["context"] = request.context
if request.obj and isinstance(request.obj, Event):
self._context["ak_client_ip"] = ip_address(
request.obj.client_ip or ClientIPMiddleware.default_ip
)
def set_http_request(self, request: HttpRequest):
"""Update context based on http request"""

View File

@@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.lib.utils.errors import exception_to_dict
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import class_to_path
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
from authentik.policies.exceptions import PolicyException
@@ -95,13 +95,10 @@ class PolicyProcess(PROCESS_CLASS):
except PolicyException as exc:
# Either use passed original exception or whatever we have
src_exc = exc.src_exc if exc.src_exc else exc
error_string = exception_to_string(src_exc)
# Create policy exception event, only when we're not debugging
if not self.request.debug:
self.create_event(
EventAction.POLICY_EXCEPTION,
message="Policy failed to execute",
exception=exception_to_dict(src_exc),
)
self.create_event(EventAction.POLICY_EXCEPTION, message=error_string)
LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
policy_result.source_binding = self.binding
@@ -146,5 +143,5 @@ class PolicyProcess(PROCESS_CLASS):
try:
self.connection.send(self.profiling_wrapper())
except Exception as exc:
LOGGER.warning("Policy failed to run", exc=exc)
LOGGER.warning("Policy failed to run", exc=exception_to_string(exc))
self.connection.send(PolicyResult(False, str(exc)))

View File

@@ -1,12 +1,9 @@
"""policy engine tests"""
from django.core.cache import cache
from django.db import connections
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.engine import PolicyEngine
@@ -22,7 +19,7 @@ class TestPolicyEngine(TestCase):
def setUp(self):
clear_policy_cache()
self.user = create_test_user()
self.user = create_test_admin_user()
self.policy_false = DummyPolicy.objects.create(
name=generate_id(), result=False, wait_min=0, wait_max=1
)
@@ -130,58 +127,3 @@ class TestPolicyEngine(TestCase):
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
def test_engine_static_bindings(self):
"""Test static bindings"""
group_a = Group.objects.create(name=generate_id())
group_b = Group.objects.create(name=generate_id())
group_b.users.add(self.user)
user = create_test_user()
for case in [
{
"message": "Group, not member",
"binding_args": {"group": group_a},
"passing": False,
},
{
"message": "Group, member",
"binding_args": {"group": group_b},
"passing": True,
},
{
"message": "User, other",
"binding_args": {"user": user},
"passing": False,
},
{
"message": "User, same",
"binding_args": {"user": self.user},
"passing": True,
},
]:
with self.subTest():
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, case["passing"])
def test_engine_group_complex(self):
"""Test more complex group setups"""
group_a = Group.objects.create(name=generate_id())
group_b = Group.objects.create(name=generate_id(), parent=group_a)
user = create_test_user()
group_b.users.add(user)
pbm = PolicyBindingModel.objects.create()
PolicyBinding.objects.create(target=pbm, order=0, group=group_a)
engine = PolicyEngine(pbm, user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertTrue(engine.result.passing)

View File

@@ -29,12 +29,13 @@ class TestPolicyProcess(TestCase):
def setUp(self):
clear_policy_cache()
self.factory = RequestFactory()
self.user = User.objects.create_user(username=generate_id())
self.user = User.objects.create_user(username="policyuser")
def test_group_passing(self):
"""Test binding to group"""
group = Group.objects.create(name=generate_id())
group = Group.objects.create(name="test-group")
group.users.add(self.user)
group.save()
binding = PolicyBinding(group=group)
request = PolicyRequest(self.user)
@@ -43,7 +44,8 @@ class TestPolicyProcess(TestCase):
def test_group_negative(self):
"""Test binding to group"""
group = Group.objects.create(name=generate_id())
group = Group.objects.create(name="test-group")
group.save()
binding = PolicyBinding(group=group)
request = PolicyRequest(self.user)
@@ -113,10 +115,8 @@ class TestPolicyProcess(TestCase):
def test_exception(self):
"""Test policy execution"""
policy = Policy.objects.create(name=generate_id())
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
policy = Policy.objects.create(name="test-execution")
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
request = PolicyRequest(self.user)
response = PolicyProcess(binding, request, None).execute()
@@ -125,15 +125,13 @@ class TestPolicyProcess(TestCase):
def test_execution_logging(self):
"""Test policy execution creates event"""
policy = DummyPolicy.objects.create(
name=generate_id(),
name="test-execution-logging",
result=False,
wait_min=0,
wait_max=1,
execution_logging=True,
)
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user
@@ -188,15 +186,13 @@ class TestPolicyProcess(TestCase):
def test_execution_logging_anonymous(self):
"""Test policy execution creates event with anonymous user"""
policy = DummyPolicy.objects.create(
name=generate_id(),
name="test-execution-logging-anon",
result=False,
wait_min=0,
wait_max=1,
execution_logging=True,
)
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
user = AnonymousUser()
@@ -223,9 +219,9 @@ class TestPolicyProcess(TestCase):
def test_raises(self):
"""Test policy that raises error"""
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
binding = PolicyBinding(
policy=policy_raises, target=Application.objects.create(name=generate_id())
policy=policy_raises, target=Application.objects.create(name="test")
)
request = PolicyRequest(self.user)
@@ -241,4 +237,4 @@ class TestPolicyProcess(TestCase):
self.assertEqual(len(events), 1)
event = events.first()
self.assertEqual(event.user["username"], self.user.username)
self.assertIn("Policy failed to execute", event.context["message"])
self.assertIn("division by zero", event.context["message"])

View File

@@ -15,14 +15,12 @@ class OAuth2Error(SentryIgnoredException):
error: str
description: str
cause: str | None = None
def create_dict(self, request: HttpRequest):
def create_dict(self):
"""Return error as dict for JSON Rendering"""
return {
"error": self.error,
"error_description": self.description,
"request_id": request.request_id,
}
def __repr__(self) -> str:
@@ -33,15 +31,9 @@ class OAuth2Error(SentryIgnoredException):
return Event.new(
EventAction.CONFIGURATION_ERROR,
message=message or self.description,
cause=self.cause,
error=self.error,
**kwargs,
)
def with_cause(self, cause: str):
self.cause = cause
return self
class RedirectUriError(OAuth2Error):
"""The request fails due to a missing, invalid, or mismatching
@@ -251,14 +243,13 @@ class TokenRevocationError(OAuth2Error):
self.description = self.errors[error]
class DeviceCodeError(TokenError):
class DeviceCodeError(OAuth2Error):
"""
Device-code flow errors
See https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
Can also use codes form TokenError
"""
errors = TokenError.errors | {
errors = {
"authorization_pending": (
"The authorization request is still pending as the end user hasn't "
"yet completed the user-interaction steps"
@@ -270,15 +261,10 @@ class DeviceCodeError(TokenError):
"authorization request but SHOULD wait for user interaction before "
"restarting to avoid unnecessary polling."
),
"slow_down": (
'A variant of "authorization_pending", the authorization request is'
"still pending and polling should continue, but the interval MUST"
"be increased by 5 seconds for this and all subsequent requests."
),
}
def __init__(self, error: str):
super().__init__(error)
super().__init__()
self.error = error
self.description = self.errors[error]

View File

@@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
from authentik.providers.oauth2.constants import TOKEN_TYPE
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
from authentik.providers.oauth2.models import (
AccessToken,
@@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
with self.assertRaises(AuthorizeError) as cm:
with self.assertRaises(AuthorizeError):
request = self.factory.get(
"/",
data={
@@ -53,7 +53,6 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_invalid_client_id(self):
"""Test invalid client ID"""
@@ -69,7 +68,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
with self.assertRaises(AuthorizeError) as cm:
with self.assertRaises(AuthorizeError):
request = self.factory.get(
"/",
data={
@@ -80,30 +79,19 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "request_not_supported")
def test_invalid_redirect_uri_missing(self):
"""test missing redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
def test_invalid_redirect_uri(self):
"""test invalid redirect URI"""
"""test missing/invalid redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
)
with self.assertRaises(RedirectUriError) as cm:
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
@@ -113,7 +101,6 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_blocked_redirect_uri(self):
"""test missing/invalid redirect URI"""
@@ -121,9 +108,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
)
with self.assertRaises(RedirectUriError) as cm:
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
@@ -133,7 +120,6 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI"""
@@ -143,6 +129,9 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
@@ -161,9 +150,12 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
)
with self.assertRaises(RedirectUriError) as cm:
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
@@ -173,7 +165,6 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_redirect_uri_invalid_regex(self):
"""test missing/invalid redirect URI (invalid regex)"""
@@ -181,9 +172,12 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
)
with self.assertRaises(RedirectUriError) as cm:
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
@@ -193,22 +187,23 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_redirect_uri_regex(self):
"""test valid redirect URI (regex)"""
def test_empty_redirect_uri(self):
"""test empty redirect URI (configure in provider)"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "http://foo.bar.baz",
"redirect_uri": "http://localhost",
},
)
OAuthAuthorizationParams.from_request(request)
@@ -263,7 +258,7 @@ class TestAuthorize(OAuthTestCase):
GrantTypes.IMPLICIT,
)
# Implicit without openid scope
with self.assertRaises(AuthorizeError) as cm:
with self.assertRaises(AuthorizeError):
request = self.factory.get(
"/",
data={
@@ -290,7 +285,7 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
)
with self.assertRaises(AuthorizeError) as cm:
with self.assertRaises(AuthorizeError):
request = self.factory.get(
"/",
data={
@@ -300,7 +295,6 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_full_code(self):
"""Test full authorization"""
@@ -619,54 +613,3 @@ class TestAuthorize(OAuthTestCase):
},
},
)
def test_openid_missing_invalid(self):
"""test request requiring an OpenID scope to be set"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": "",
},
)
with self.assertRaises(AuthorizeError) as cm:
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "scope_openid_missing")
@apply_blueprint("system/providers-oauth2.yaml")
def test_offline_access_invalid(self):
"""test request for offline_access with invalid response type"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
"nonce": generate_id(),
},
)
parsed = OAuthAuthorizationParams.from_request(request)
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)

View File

@@ -68,11 +68,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_no_provider(self):
@@ -91,11 +87,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_permission_denied(self):
@@ -118,11 +110,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_incorrect_scopes(self):

View File

@@ -68,11 +68,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_wrong_token(self):
@@ -89,11 +85,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_no_provider(self):
@@ -112,11 +104,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_permission_denied(self):
@@ -139,11 +127,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_successful(self):

View File

@@ -68,11 +68,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_wrong_token(self):
@@ -90,11 +86,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_no_provider(self):
@@ -114,11 +106,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_permission_denied(self):
@@ -142,11 +130,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content.decode(),
{
"error": "invalid_grant",
"error_description": TokenError.errors["invalid_grant"],
"request_id": response.headers["X-authentik-id"],
},
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
)
def test_successful(self):

View File

@@ -80,7 +80,6 @@ class TestTokenPKCE(OAuthTestCase):
"revoked, does not match the redirection URI used in the authorization "
"request, or was issued to another client"
),
"request_id": response.headers["X-authentik-id"],
},
)
self.assertEqual(response.status_code, 400)
@@ -137,7 +136,6 @@ class TestTokenPKCE(OAuthTestCase):
"revoked, does not match the redirection URI used in the authorization "
"request, or was issued to another client"
),
"request_id": response.headers["X-authentik-id"],
},
)
self.assertEqual(response.status_code, 400)

View File

@@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
allowed_redirect_urls = self.provider.redirect_uris
if not self.redirect_uri:
LOGGER.warning("Missing redirect uri.")
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
raise RedirectUriError("", allowed_redirect_urls)
if len(allowed_redirect_urls) < 1:
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
@@ -219,14 +219,10 @@ class OAuthAuthorizationParams:
provider=self.provider,
)
if not match_found:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_no_match"
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
# Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_forbidden_scheme"
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
def check_scope(self, github_compat=False):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
@@ -255,9 +251,7 @@ class OAuthAuthorizationParams:
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
):
LOGGER.warning("Missing 'openid' scope.")
raise AuthorizeError(
self.redirect_uri, "invalid_scope", self.grant_type, self.state
).with_cause("scope_openid_missing")
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
if SCOPE_OFFLINE_ACCESS in self.scope:
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
# Don't explicitly request consent with offline_access, as the spec allows for
@@ -292,9 +286,7 @@ class OAuthAuthorizationParams:
return
if not self.nonce:
LOGGER.warning("Missing nonce for OpenID Request")
raise AuthorizeError(
self.redirect_uri, "invalid_request", self.grant_type, self.state
).with_cause("nonce_missing")
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
def check_code_challenge(self):
"""PKCE validation of the transformation method."""
@@ -353,10 +345,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
self.request, github_compat=self.github_compat
)
except AuthorizeError as error:
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
raise RequestValidationError(error.get_response(self.request)) from None
except OAuth2Error as error:
LOGGER.warning(error.description, cause=error.cause)
LOGGER.warning(error.description)
raise RequestValidationError(
bad_request_message(self.request, error.description, title=error.error)
) from None

View File

@@ -2,7 +2,7 @@
from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, JsonResponse
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.utils.timezone import now
@@ -14,9 +14,7 @@ from structlog.stdlib import get_logger
from authentik.core.models import Application
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.errors import DeviceCodeError
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.utils import TokenResponse
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
LOGGER = get_logger()
@@ -30,36 +28,38 @@ class DeviceView(View):
provider: OAuth2Provider
scopes: list[str] = []
def parse_request(self):
def parse_request(self) -> HttpResponse | None:
"""Parse incoming request"""
client_id = self.request.POST.get("client_id", None)
if not client_id:
raise DeviceCodeError("invalid_client")
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
return HttpResponseBadRequest()
provider = OAuth2Provider.objects.filter(
client_id=client_id,
).first()
if not provider:
raise DeviceCodeError("invalid_client")
return HttpResponseBadRequest()
try:
_ = provider.application
except Application.DoesNotExist:
raise DeviceCodeError("invalid_client") from None
return HttpResponseBadRequest()
self.provider = provider
self.client_id = client_id
self.scopes = self.request.POST.get("scope", "").split(" ")
return None
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
throttle = AnonRateThrottle()
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
if not throttle.allow_request(request, self):
return TokenResponse(DeviceCodeError("slow_down").create_dict(request), status=429)
return HttpResponse(status=429)
return super().dispatch(request, *args, **kwargs)
def post(self, request: HttpRequest) -> HttpResponse:
"""Generate device token"""
try:
self.parse_request()
except DeviceCodeError as exc:
return TokenResponse(exc.create_dict(request), status=400)
resp = self.parse_request()
if resp:
return resp
until = timedelta_from_string(self.provider.access_code_validity)
token: DeviceToken = DeviceToken.objects.create(
expires=now() + until, provider=self.provider, _scope=" ".join(self.scopes)
@@ -67,7 +67,7 @@ class DeviceView(View):
device_url = self.request.build_absolute_uri(
reverse("authentik_providers_oauth2_root:device-login")
)
return TokenResponse(
return JsonResponse(
{
"device_code": token.device_code,
"verification_uri": device_url,

View File

@@ -598,9 +598,9 @@ class TokenView(View):
return TokenResponse(self.create_device_code_response())
raise TokenError("unsupported_grant_type")
except (TokenError, DeviceCodeError) as error:
return TokenResponse(error.create_dict(request), status=400)
return TokenResponse(error.create_dict(), status=400)
except UserAuthError as error:
return TokenResponse(error.create_dict(request), status=403)
return TokenResponse(error.create_dict(), status=403)
def create_code_response(self) -> dict[str, Any]:
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1"""

View File

@@ -65,7 +65,7 @@ class TokenRevokeView(View):
return TokenResponse(data={}, status=200)
except TokenRevocationError as exc:
return TokenResponse(exc.create_dict(request), status=401)
return TokenResponse(exc.create_dict(), status=401)
except Http404:
# Token not found should return a HTTP 200
# https://datatracker.ietf.org/doc/html/rfc7009#section-2.2

View File

@@ -102,7 +102,6 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
# Buffer sizes for large headers with JWTs
"nginx.ingress.kubernetes.io/proxy-buffers-number": "4",
"nginx.ingress.kubernetes.io/proxy-buffer-size": "16k",
"nginx.ingress.kubernetes.io/proxy-busy-buffers-size": "32k",
# Enable TLS in traefik
"traefik.ingress.kubernetes.io/router.tls": "true",
}

View File

@@ -23,6 +23,7 @@ from authentik.core.models import Application
from authentik.events.models import Event, EventAction
from authentik.lib.expression.exceptions import ControlFlowException
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestResultSerializer
from authentik.policies.engine import PolicyEngine
from authentik.policies.types import PolicyResult
@@ -141,9 +142,9 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Failed to evaluate property-mapping",
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=exc.mapping,
).with_exception(exc).save()
).save()
return None
return b64encode(packet.RequestPacket()).decode()

View File

@@ -2,7 +2,7 @@
from enum import Enum
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest
@@ -12,95 +12,19 @@ from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
)
from pydanticscim.user import AddressKind
from pydanticscim.user import User as BaseUser
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
class Address(BaseModel):
formatted: str | None = Field(
None,
description="The full mailing address, formatted for display "
"or use with a mailing label. This attribute MAY contain newlines.",
)
streetAddress: str | None = Field(
None,
description="The full street address component, which may "
"include house number, street name, P.O. box, and multi-line "
"extended street address information. This attribute MAY contain newlines.",
)
locality: str | None = Field(None, description="The city or locality component.")
region: str | None = Field(None, description="The state or region component.")
postalCode: str | None = Field(None, description="The zip code or postal code component.")
country: str | None = Field(None, description="The country name component.")
type: AddressKind | None = Field(
None,
description="A label indicating the attribute's function, e.g., 'work' or 'home'.",
)
primary: bool | None = None
class Manager(BaseModel):
value: str | None = Field(
None,
description="The id of the SCIM resource representingthe User's manager. REQUIRED.",
)
ref: AnyUrl | None = Field(
None,
alias="$ref",
description="The URI of the SCIM resource representing the User's manager. REQUIRED.",
)
displayName: str | None = Field(
None,
description="The displayName of the User's manager. OPTIONAL and READ-ONLY.",
)
class EnterpriseUser(BaseModel):
employeeNumber: str | None = Field(
None,
description="Numeric or alphanumeric identifier assigned to a person, "
"typically based on order of hire or association with anorganization.",
)
costCenter: str | None = Field(None, description="Identifies the name of a cost center.")
organization: str | None = Field(None, description="Identifies the name of an organization.")
division: str | None = Field(None, description="Identifies the name of a division.")
department: str | None = Field(
None,
description="Numeric or alphanumeric identifier assigned to a person,"
" typically based on order of hire or association with anorganization.",
)
manager: Manager | None = Field(
None,
description="The User's manager. A complex type that optionally allows "
"service providers to represent organizational hierarchy by referencing"
" the 'id' attribute of another User.",
)
class User(BaseUser):
"""Modified User schema with added externalId field"""
model_config = ConfigDict(serialize_by_alias=True)
id: str | int | None = None
schemas: list[str] = [SCIM_USER_SCHEMA]
externalId: str | None = None
meta: dict | None = None
addresses: list[Address] | None = Field(
None,
description=(
"A physical mailing address for this User. Canonical type "
"values of 'work', 'home', and 'other'."
),
)
enterprise_user: EnterpriseUser | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
serialization_alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class Group(BaseGroup):
@@ -168,7 +92,7 @@ class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
op: PatchOp
path: str | None = None
path: str | None
class SCIMError(BaseSCIMError):

View File

@@ -28,6 +28,7 @@ from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik import get_full_version
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.utils.errors import exception_to_string
# set the default Django settings module for the 'celery' program.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
@@ -82,8 +83,8 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
CTX_TASK_ID.set(...)
if not should_ignore_exception(exception):
Event.new(
EventAction.SYSTEM_EXCEPTION, message="Failed to execute task", task_id=task_id
).with_exception(exception).save()
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
).save()
def _get_startup_tasks_default_tenant() -> list[Callable]:

View File

@@ -49,8 +49,6 @@ class ReadyView(View):
def dispatch(self, request: HttpRequest) -> HttpResponse:
try:
for db_conn in connections.all():
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
except OperationalError: # pragma: no cover
return HttpResponse(status=503)

View File

@@ -156,17 +156,16 @@ SPECTACULAR_SETTINGS = {
},
"ENUM_NAME_OVERRIDES": {
"CountryCodeEnum": "django_countries.countries",
"DeviceClassesEnum": "authentik.stages.authenticator_validate.models.DeviceClasses",
"EventActions": "authentik.events.models.EventAction",
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
"FlowLayoutEnum": "authentik.flows.models.FlowLayout",
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
"UserTypeEnum": "authentik.core.models.UserTypes",
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
"UserTypeEnum": "authentik.core.models.UserTypes",
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
},
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,

View File

@@ -4,11 +4,11 @@ from pathlib import Path
from secrets import token_urlsafe
from tempfile import gettempdir
from django.test import TransactionTestCase
from django.test import TestCase
from django.urls import reverse
class TestRoot(TransactionTestCase):
class TestRoot(TestCase):
"""Test root application"""
def setUp(self):

View File

@@ -8,6 +8,7 @@ from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.config import CONFIG
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.lib.utils.errors import exception_to_string
from authentik.root.celery import CELERY_APP
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync
@@ -63,5 +64,5 @@ def kerberos_sync_single(self, source_pk: str):
syncer.sync()
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
except StopSync as exc:
LOGGER.warning("Error syncing kerberos", exc=exc, source=source)
LOGGER.warning(exception_to_string(exc))
self.set_error(exc)

View File

@@ -12,6 +12,7 @@ from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.config import CONFIG
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.root.celery import CELERY_APP
from authentik.sources.ldap.models import LDAPSource
@@ -148,5 +149,5 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key:
cache.delete(page_cache_key)
except (LDAPException, StopSync) as exc:
# No explicit event is created here as .set_status with an error will do that
LOGGER.warning("Failed to sync LDAP", exc=exc, source=source)
LOGGER.warning(exception_to_string(exc))
self.set_error(exc)

View File

@@ -10,7 +10,6 @@ AUTHENTIK_SOURCES_OAUTH_TYPES = [
"authentik.sources.oauth.types.apple",
"authentik.sources.oauth.types.azure_ad",
"authentik.sources.oauth.types.discord",
"authentik.sources.oauth.types.entra_id",
"authentik.sources.oauth.types.facebook",
"authentik.sources.oauth.types.github",
"authentik.sources.oauth.types.gitlab",

View File

@@ -232,7 +232,7 @@ class GoogleOAuthSource(CreatableType, OAuthSource):
class AzureADOAuthSource(CreatableType, OAuthSource):
"""(Deprecated) Social Login using Azure AD."""
"""Social Login using Azure AD."""
class Meta:
abstract = True
@@ -240,17 +240,6 @@ class AzureADOAuthSource(CreatableType, OAuthSource):
verbose_name_plural = _("Azure AD OAuth Sources")
# TODO: When removing this, add a migration for OAuthSource that sets
# provider_type to `entraid` if it is currently `azuread`
class EntraIDOAuthSource(CreatableType, OAuthSource):
"""Social Login using Entra ID."""
class Meta:
abstract = True
verbose_name = _("Entra ID OAuth Source")
verbose_name_plural = _("Entra ID OAuth Sources")
class OpenIDConnectOAuthSource(CreatableType, OAuthSource):
"""Login using a Generic OpenID-Connect compliant provider."""

View File

@@ -1,12 +1,12 @@
"""Entra ID Type tests"""
"""azure ad Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.entra_id import EntraIDOAuthCallback, EntraIDType
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
EID_USER = {
AAD_USER = {
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users/$entity",
"@odata.id": (
"https://graph.microsoft.com/v2/7ce9b89e-646a-41d2-9fa6-8371c6a8423d/"
@@ -41,11 +41,11 @@ class TestTypeAzureAD(TestCase):
def test_enroll_context(self):
"""Test azure_ad Enrollment context"""
ak_context = EntraIDType().get_base_user_properties(source=self.source, info=EID_USER)
self.assertEqual(ak_context["username"], EID_USER["userPrincipalName"])
self.assertEqual(ak_context["email"], EID_USER["mail"])
self.assertEqual(ak_context["name"], EID_USER["displayName"])
ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER)
self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"])
self.assertEqual(ak_context["email"], AAD_USER["mail"])
self.assertEqual(ak_context["name"], AAD_USER["displayName"])
def test_user_id(self):
"""Test Entra ID user ID"""
self.assertEqual(EntraIDOAuthCallback().get_user_id(EID_USER), EID_USER["id"])
"""Test azure AD user ID"""
self.assertEqual(AzureADOAuthCallback().get_user_id(AAD_USER), AAD_USER["id"])

View File

@@ -1,17 +1,105 @@
"""AzureAD OAuth2 Views"""
from authentik.sources.oauth.types.entra_id import EntraIDType
from authentik.sources.oauth.types.registry import registry
from typing import Any
# TODO: When removing this, add a migration for OAuthSource that sets
# provider_type to `entraid` if it is currently `azuread`
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger()
class AzureADOAuthRedirect(OAuthRedirect):
"""Azure AD OAuth2 Redirect"""
def get_additional_parameters(self, source): # pragma: no cover
return {
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
}
class AzureADClient(UserprofileHeaderAuthClient):
"""Fetch AzureAD group information"""
def get_profile_info(self, token):
profile_data = super().get_profile_info(token)
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
return profile_data
group_response = self.session.request(
"get",
"https://graph.microsoft.com/v1.0/me/memberOf",
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
try:
group_response.raise_for_status()
except RequestException as exc:
LOGGER.warning(
"Unable to fetch user profile",
exc=exc,
response=exc.response.text if exc.response else str(exc),
)
return None
profile_data["raw_groups"] = group_response.json()
return profile_data
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
"""AzureAD OAuth2 Callback"""
client_class = AzureADClient
def get_user_id(self, info: dict[str, str]) -> str:
# Default try to get `id` for the Graph API endpoint
# fallback to OpenID logic in case the profile URL was changed
return info.get("id", super().get_user_id(info))
@registry.register()
class AzureADType(EntraIDType):
class AzureADType(SourceType):
"""Azure AD Type definition"""
callback_view = AzureADOAuthCallback
redirect_view = AzureADOAuthRedirect
verbose_name = "Azure AD"
name = "azuread"
urls_customizable = True
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
profile_url = "https://graph.microsoft.com/v1.0/me"
oidc_well_known_url = (
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
)
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
# Format group info
groups = []
group_id_dict = {}
for group in info.get("raw_groups", {}).get("value", []):
if group["@odata.type"] != "#microsoft.graph.group":
continue
groups.append(group["id"])
group_id_dict[group["id"]] = group
info["raw_groups"] = group_id_dict
return {
"username": info.get("userPrincipalName"),
"email": mail,
"name": info.get("displayName"),
"groups": groups,
}
def get_base_group_properties(self, source, group_id, **kwargs):
raw_group = kwargs["info"]["raw_groups"][group_id]
return {
"name": raw_group["displayName"],
}

View File

@@ -1,102 +0,0 @@
"""EntraID OAuth2 Views"""
from typing import Any
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger()
class EntraIDOAuthRedirect(OAuthRedirect):
"""Entra ID OAuth2 Redirect"""
def get_additional_parameters(self, source): # pragma: no cover
return {
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
}
class EntraIDClient(UserprofileHeaderAuthClient):
"""Fetch EntraID group information"""
def get_profile_info(self, token):
profile_data = super().get_profile_info(token)
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
return profile_data
group_response = self.session.request(
"get",
"https://graph.microsoft.com/v1.0/me/memberOf",
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
try:
group_response.raise_for_status()
except RequestException as exc:
LOGGER.warning(
"Unable to fetch user profile",
exc=exc,
response=exc.response.text if exc.response else str(exc),
)
return None
profile_data["raw_groups"] = group_response.json()
return profile_data
class EntraIDOAuthCallback(OpenIDConnectOAuth2Callback):
"""EntraID OAuth2 Callback"""
client_class = EntraIDClient
def get_user_id(self, info: dict[str, str]) -> str:
# Default try to get `id` for the Graph API endpoint
# fallback to OpenID logic in case the profile URL was changed
return info.get("id", super().get_user_id(info))
@registry.register()
class EntraIDType(SourceType):
"""Entra ID Type definition"""
callback_view = EntraIDOAuthCallback
redirect_view = EntraIDOAuthRedirect
verbose_name = "Entra ID"
name = "entraid"
urls_customizable = True
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
profile_url = "https://graph.microsoft.com/v1.0/me"
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
# Format group info
groups = []
group_id_dict = {}
for group in info.get("raw_groups", {}).get("value", []):
if group["@odata.type"] != "#microsoft.graph.group":
continue
groups.append(group["id"])
group_id_dict[group["id"]] = group
info["raw_groups"] = group_id_dict
return {
"username": info.get("userPrincipalName"),
"email": mail,
"name": info.get("displayName"),
"groups": groups,
}
def get_base_group_properties(self, source, group_id, **kwargs):
raw_group = kwargs["info"]["raw_groups"][group_id]
return {
"name": raw_group["displayName"],
}

View File

@@ -18,7 +18,6 @@ class SCIMSourceGroupSerializer(SourceSerializer):
model = SCIMSourceGroup
fields = [
"id",
"external_id",
"group",
"group_obj",
"source",
@@ -32,5 +31,5 @@ class SCIMSourceGroupViewSet(UsedByMixin, ModelViewSet):
queryset = SCIMSourceGroup.objects.all().select_related("group")
serializer_class = SCIMSourceGroupSerializer
filterset_fields = ["source__slug", "group__name", "group__group_uuid"]
search_fields = ["source__slug", "group__name", "attributes", "external_id"]
search_fields = ["source__slug", "group__name", "attributes"]
ordering = ["group__name"]

View File

@@ -18,7 +18,6 @@ class SCIMSourceUserSerializer(SourceSerializer):
model = SCIMSourceUser
fields = [
"id",
"external_id",
"user",
"user_obj",
"source",
@@ -32,5 +31,5 @@ class SCIMSourceUserViewSet(UsedByMixin, ModelViewSet):
queryset = SCIMSourceUser.objects.all().select_related("user")
serializer_class = SCIMSourceUserSerializer
filterset_fields = ["source__slug", "user__username", "user__id"]
search_fields = ["source__slug", "user__username", "attributes", "user__uuid", "external_id"]
search_fields = ["source__slug", "user__username", "attributes"]
ordering = ["user__username"]

View File

@@ -1,4 +0,0 @@
SCIM_URN_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_URN_GROUP = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_URN_USER = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_URN_USER_ENTERPRISE = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"

View File

@@ -0,0 +1,8 @@
"""SCIM Errors"""
from authentik.lib.sentry import SentryIgnoredException
class PatchError(SentryIgnoredException):
"""Error raised within an atomic block when an error happened
so nothing is saved"""

View File

@@ -1,98 +0,0 @@
# Generated by Django 5.1.11 on 2025-07-13 01:07
import uuid
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_ext_id(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
SCIMSourceUser = apps.get_model("authentik_sources_scim", "SCIMSourceUser")
SCIMSourceGroup = apps.get_model("authentik_sources_scim", "SCIMSourceGroup")
db_alias = schema_editor.connection.alias
for user in SCIMSourceUser.objects.using(db_alias).all():
user.external_id = user.id
user.save(update_fields=["external_id"])
for group in SCIMSourceGroup.objects.using(db_alias).all():
group.external_id = group.id
group.save(update_fields=["external_id"])
class Migration(migrations.Migration):
dependencies = [
("authentik_sources_scim", "0002_scimsourcepropertymapping"),
]
operations = [
migrations.AlterUniqueTogether(
name="scimsourcegroup",
unique_together=set(),
),
migrations.AlterUniqueTogether(
name="scimsourceuser",
unique_together=set(),
),
migrations.AddField(
model_name="scimsourcegroup",
name="external_id",
field=models.TextField(default=None, null=True),
preserve_default=False,
),
migrations.AddField(
model_name="scimsourceuser",
name="external_id",
field=models.TextField(default=None, null=True),
preserve_default=False,
),
migrations.AlterUniqueTogether(
name="scimsourcegroup",
unique_together={("external_id", "source")},
),
migrations.AlterUniqueTogether(
name="scimsourceuser",
unique_together={("external_id", "source")},
),
migrations.RunPython(migrate_ext_id, migrations.RunPython.noop),
migrations.AlterField(
model_name="scimsourcegroup",
name="external_id",
field=models.TextField(),
preserve_default=False,
),
migrations.AlterField(
model_name="scimsourceuser",
name="external_id",
field=models.TextField(),
preserve_default=False,
),
migrations.AddIndex(
model_name="scimsourcegroup",
index=models.Index(fields=["external_id"], name="authentik_s_externa_05e346_idx"),
),
migrations.AddIndex(
model_name="scimsourceuser",
index=models.Index(fields=["external_id"], name="authentik_s_externa_4bd760_idx"),
),
migrations.AlterField(
model_name="scimsourcegroup",
name="id",
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
),
migrations.AlterField(
model_name="scimsourceuser",
name="id",
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
),
migrations.AddField(
model_name="scimsourcegroup",
name="last_update",
field=models.DateTimeField(auto_now=True),
),
migrations.AddField(
model_name="scimsourceuser",
name="last_update",
field=models.DateTimeField(auto_now=True),
),
]

View File

@@ -1,7 +1,6 @@
"""SCIM Source"""
from typing import Any
from uuid import uuid4
from django.db import models
from django.templatetags.static import static
@@ -104,12 +103,10 @@ class SCIMSourcePropertyMapping(PropertyMapping):
class SCIMSourceUser(SerializerModel):
"""Mapping of a user and source to a SCIM user ID"""
id = models.TextField(primary_key=True, default=uuid4)
external_id = models.TextField()
id = models.TextField(primary_key=True)
user = models.ForeignKey(User, on_delete=models.CASCADE)
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
last_update = models.DateTimeField(auto_now=True)
@property
def serializer(self) -> BaseSerializer:
@@ -118,10 +115,7 @@ class SCIMSourceUser(SerializerModel):
return SCIMSourceUserSerializer
class Meta:
unique_together = (("external_id", "source"),)
indexes = [
models.Index(fields=["external_id"]),
]
unique_together = (("id", "user", "source"),)
def __str__(self) -> str:
return f"SCIM User {self.user_id} to {self.source_id}"
@@ -130,12 +124,10 @@ class SCIMSourceUser(SerializerModel):
class SCIMSourceGroup(SerializerModel):
"""Mapping of a group and source to a SCIM user ID"""
id = models.TextField(primary_key=True, default=uuid4)
external_id = models.TextField()
id = models.TextField(primary_key=True)
group = models.ForeignKey(Group, on_delete=models.CASCADE)
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
last_update = models.DateTimeField(auto_now=True)
@property
def serializer(self) -> BaseSerializer:
@@ -144,10 +136,7 @@ class SCIMSourceGroup(SerializerModel):
return SCIMSourceGroupSerializer
class Meta:
unique_together = (("external_id", "source"),)
indexes = [
models.Index(fields=["external_id"]),
]
unique_together = (("id", "group", "source"),)
def __str__(self) -> str:
return f"SCIM Group {self.group_id} to {self.source_id}"

View File

@@ -1,180 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from authentik.sources.scim.constants import (
SCIM_URN_GROUP,
SCIM_URN_SCHEMA,
SCIM_URN_USER,
SCIM_URN_USER_ENTERPRISE,
)
# Token types for SCIM path parsing
class TokenType(Enum):
ATTRIBUTE = "ATTRIBUTE"
DOT = "DOT"
LBRACKET = "LBRACKET"
RBRACKET = "RBRACKET"
LPAREN = "LPAREN"
RPAREN = "RPAREN"
STRING = "STRING"
NUMBER = "NUMBER"
BOOLEAN = "BOOLEAN"
NULL = "NULL"
OPERATOR = "OPERATOR"
AND = "AND"
OR = "OR"
NOT = "NOT"
EOF = "EOF"
@dataclass
class Token:
type: TokenType
value: str
position: int = 0
class SCIMPathLexer:
"""Lexer for SCIM paths and filter expressions"""
OPERATORS = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
def __init__(self, text: str):
self.schema_urns = [
SCIM_URN_SCHEMA,
SCIM_URN_GROUP,
SCIM_URN_USER,
SCIM_URN_USER_ENTERPRISE,
]
self.text = text
self.pos = 0
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
def advance(self):
"""Move to next character"""
self.pos += 1
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
def skip_whitespace(self):
"""Skip whitespace characters"""
while self.current_char and self.current_char.isspace():
self.advance()
def read_string(self, quote_char):
"""Read a quoted string"""
value = ""
self.advance() # Skip opening quote
while self.current_char and self.current_char != quote_char:
if self.current_char == "\\":
self.advance()
if self.current_char:
value += self.current_char
self.advance()
else:
value += self.current_char
self.advance()
if self.current_char == quote_char:
self.advance() # Skip closing quote
return value
def read_number(self):
"""Read a number (integer or float)"""
value = ""
while self.current_char and (self.current_char.isdigit() or self.current_char == "."):
value += self.current_char
self.advance()
return value
def read_identifier(self):
"""Read an identifier (attribute name or operator) - supports URN format"""
value = ""
while self.current_char and (self.current_char.isalnum() or self.current_char in "_-:"):
value += self.current_char
self.advance()
# If the identifier value so far is a schema URN, take that as the identifier and
# treat the next part as a sub_attribute
if value in self.schema_urns:
self.current_char = "."
return value
# Handle dots within URN identifiers (like "2.0")
# A dot is part of the identifier if it's followed by a digit
if (
self.current_char == "."
and self.pos + 1 < len(self.text)
and self.text[self.pos + 1].isdigit()
):
value += self.current_char
self.advance()
# Continue reading digits after the dot
while self.current_char and self.current_char.isdigit():
value += self.current_char
self.advance()
return value
def get_next_token(self) -> Token: # noqa PLR0911
"""Get the next token from the input"""
while self.current_char:
if self.current_char.isspace():
self.skip_whitespace()
continue
if self.current_char == ".":
self.advance()
return Token(TokenType.DOT, ".")
if self.current_char == "[":
self.advance()
return Token(TokenType.LBRACKET, "[")
if self.current_char == "]":
self.advance()
return Token(TokenType.RBRACKET, "]")
if self.current_char == "(":
self.advance()
return Token(TokenType.LPAREN, "(")
if self.current_char == ")":
self.advance()
return Token(TokenType.RPAREN, ")")
if self.current_char in "\"'":
quote_char = self.current_char
value = self.read_string(quote_char)
return Token(TokenType.STRING, value)
if self.current_char.isdigit():
value = self.read_number()
return Token(TokenType.NUMBER, value)
if self.current_char.isalpha() or self.current_char == "_":
value = self.read_identifier()
# Check for special keywords
if value.lower() == "true":
return Token(TokenType.BOOLEAN, True)
elif value.lower() == "false":
return Token(TokenType.BOOLEAN, False)
elif value.lower() == "null":
return Token(TokenType.NULL, None)
elif value.lower() == "and":
return Token(TokenType.AND, "and")
elif value.lower() == "or":
return Token(TokenType.OR, "or")
elif value.lower() == "not":
return Token(TokenType.NOT, "not")
elif value.lower() in self.OPERATORS:
return Token(TokenType.OPERATOR, value.lower())
else:
return Token(TokenType.ATTRIBUTE, value)
# Skip unknown characters
self.advance()
return Token(TokenType.EOF, "")

View File

@@ -1,131 +0,0 @@
from typing import Any
from authentik.sources.scim.patch.lexer import SCIMPathLexer, TokenType
class SCIMPathParser:
"""Parser for SCIM paths including filter expressions"""
def __init__(self):
self.lexer = None
self.current_token = None
def parse_path(self, path: str | None) -> list[dict[str, Any]]:
"""Parse a SCIM path into components"""
self.lexer = SCIMPathLexer(path)
self.current_token = self.lexer.get_next_token()
components = []
while self.current_token.type != TokenType.EOF:
component = self._parse_path_component()
if component:
components.append(component)
return components
def _parse_path_component(self) -> dict[str, Any] | None:
"""Parse a single path component"""
if self.current_token.type != TokenType.ATTRIBUTE:
return None
attribute = self.current_token.value
self._consume(TokenType.ATTRIBUTE)
filter_expr = None
sub_attribute = None
# Check for filter expression
if self.current_token.type == TokenType.LBRACKET:
self._consume(TokenType.LBRACKET)
filter_expr = self._parse_filter_expression()
self._consume(TokenType.RBRACKET)
# Check for sub-attribute
if self.current_token.type == TokenType.DOT:
self._consume(TokenType.DOT)
if self.current_token.type == TokenType.ATTRIBUTE:
sub_attribute = self.current_token.value
self._consume(TokenType.ATTRIBUTE)
return {"attribute": attribute, "filter": filter_expr, "sub_attribute": sub_attribute}
def _parse_filter_expression(self) -> dict[str, Any] | None:
"""Parse a filter expression like 'primary eq true' or
'type eq "work" and primary eq true'"""
return self._parse_or_expression()
def _parse_or_expression(self) -> dict[str, Any] | None:
"""Parse OR expressions"""
left = self._parse_and_expression()
while self.current_token.type == TokenType.OR:
self._consume(TokenType.OR)
right = self._parse_and_expression()
left = {"type": "logical", "operator": "or", "left": left, "right": right}
return left
def _parse_and_expression(self) -> dict[str, Any] | None:
"""Parse AND expressions"""
left = self._parse_primary_expression()
while self.current_token.type == TokenType.AND:
self._consume(TokenType.AND)
right = self._parse_primary_expression()
left = {"type": "logical", "operator": "and", "left": left, "right": right}
return left
def _parse_primary_expression(self) -> dict[str, Any] | None:
"""Parse primary expressions (attribute operator value)"""
if self.current_token.type == TokenType.LPAREN:
self._consume(TokenType.LPAREN)
expr = self._parse_or_expression()
self._consume(TokenType.RPAREN)
return expr
if self.current_token.type == TokenType.NOT:
self._consume(TokenType.NOT)
expr = self._parse_primary_expression()
return {"type": "logical", "operator": "not", "operand": expr}
if self.current_token.type != TokenType.ATTRIBUTE:
return None
attribute = self.current_token.value
self._consume(TokenType.ATTRIBUTE)
if self.current_token.type != TokenType.OPERATOR:
return None
operator = self.current_token.value
self._consume(TokenType.OPERATOR)
# Parse value
value = None
if self.current_token.type == TokenType.STRING:
value = self.current_token.value
self._consume(TokenType.STRING)
elif self.current_token.type == TokenType.NUMBER:
value = (
float(self.current_token.value)
if "." in self.current_token.value
else int(self.current_token.value)
)
self._consume(TokenType.NUMBER)
elif self.current_token.type == TokenType.BOOLEAN:
value = self.current_token.value
self._consume(TokenType.BOOLEAN)
elif self.current_token.type == TokenType.NULL:
value = None
self._consume(TokenType.NULL)
return {"type": "comparison", "attribute": attribute, "operator": operator, "value": value}
def _consume(self, expected_type: TokenType):
"""Consume a token of the expected type"""
if self.current_token.type == expected_type:
self.current_token = self.lexer.get_next_token()
else:
raise ValueError(f"Expected {expected_type}, got {self.current_token.type}")

View File

@@ -1,246 +0,0 @@
from typing import Any
from authentik.providers.scim.clients.schema import PatchOp, PatchOperation
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
from authentik.sources.scim.patch.parser import SCIMPathParser
class SCIMPatchProcessor:
"""Processes SCIM patch operations on Python dictionaries"""
def __init__(self):
self.parser = SCIMPathParser()
def apply_patches(self, data: dict[str, Any], patches: list[PatchOperation]) -> dict[str, Any]:
"""Apply a list of patch operations to the data"""
result = data.copy()
for _patch in patches:
patch = PatchOperation.model_validate(_patch)
if patch.path is None:
# Handle operations with no path - value contains attribute paths as keys
self._apply_bulk_operation(result, patch.op, patch.value)
elif patch.op == PatchOp.add:
self._apply_add(result, patch.path, patch.value)
elif patch.op == PatchOp.remove:
self._apply_remove(result, patch.path)
elif patch.op == PatchOp.replace:
self._apply_replace(result, patch.path, patch.value)
return result
def _apply_bulk_operation(
self, data: dict[str, Any], operation: PatchOp, value: dict[str, Any]
):
"""Apply bulk operations when path is None"""
if not isinstance(value, dict):
return
for path, val in value.items():
if operation == PatchOp.add:
self._apply_add(data, path, val)
elif operation == PatchOp.remove:
self._apply_remove(data, path)
elif operation == PatchOp.replace:
self._apply_replace(data, path, val)
def _apply_add(self, data: dict[str, Any], path: str, value: Any):
"""Apply ADD operation"""
components = self.parser.parse_path(path)
if len(components) == 1 and not components[0]["filter"]:
# Simple path
attr = components[0]["attribute"]
if components[0]["sub_attribute"]:
if attr not in data:
data[attr] = {}
# Somewhat hacky workaround for the manager attribute of the enterprise schema
# ideally we'd do this based on the schema
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
data[attr][components[0]["sub_attribute"]] = {"value": value}
else:
data[attr][components[0]["sub_attribute"]] = value
elif attr in data:
data[attr].append(value)
else:
data[attr] = value
else:
# Complex path with filters
self._navigate_and_modify(data, components, value, "add")
def _apply_remove(self, data: dict[str, Any], path: str):
"""Apply REMOVE operation"""
components = self.parser.parse_path(path)
if len(components) == 1 and not components[0]["filter"]:
# Simple path
attr = components[0]["attribute"]
if components[0]["sub_attribute"]:
if attr in data and isinstance(data[attr], dict):
data[attr].pop(components[0]["sub_attribute"], None)
else:
data.pop(attr, None)
else:
# Complex path with filters
self._navigate_and_modify(data, components, None, "remove")
def _apply_replace(self, data: dict[str, Any], path: str, value: Any):
"""Apply REPLACE operation"""
components = self.parser.parse_path(path)
if len(components) == 1 and not components[0]["filter"]:
# Simple path
attr = components[0]["attribute"]
if components[0]["sub_attribute"]:
if attr not in data:
data[attr] = {}
# Somewhat hacky workaround for the manager attribute of the enterprise schema
# ideally we'd do this based on the schema
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
data[attr][components[0]["sub_attribute"]] = {"value": value}
else:
data[attr][components[0]["sub_attribute"]] = value
else:
data[attr] = value
else:
# Complex path with filters
self._navigate_and_modify(data, components, value, "replace")
def _navigate_and_modify( # noqa PLR0912
self, data: dict[str, Any], components: list[dict[str, Any]], value: Any, operation: str
):
"""Navigate through complex paths and apply modifications"""
current = data
for i, component in enumerate(components):
attr = component["attribute"]
filter_expr = component["filter"]
sub_attr = component["sub_attribute"]
if filter_expr:
# Handle array with filter
if attr not in current:
if operation == "add":
current[attr] = []
else:
return
if not isinstance(current[attr], list):
return
# Find matching items
matching_items = []
for item in current[attr]:
if self._matches_filter(item, filter_expr):
matching_items.append(item)
if not matching_items and operation == "add":
# Create new item if none match (only for simple comparison filters)
if filter_expr.get("type", "comparison") == "comparison":
new_item = {filter_expr["attribute"]: filter_expr["value"]}
current[attr].append(new_item)
matching_items = [new_item]
# Apply operation to matching items
for item in matching_items:
if sub_attr:
if operation in {"add", "replace"}:
item[sub_attr] = value
elif operation == "remove":
item.pop(sub_attr, None)
elif operation in {"add", "replace"}:
if isinstance(value, dict):
item.update(value)
else:
# If value is not a dict, we can't merge it
pass
elif operation == "remove":
# Remove the entire item
if item in current[attr]:
current[attr].remove(item)
# Handle simple attribute
elif i == len(components) - 1:
# Last component
if sub_attr:
if attr not in current:
current[attr] = {}
if operation in {"add", "replace"}:
current[attr][sub_attr] = value
elif operation == "remove":
current[attr].pop(sub_attr, None)
elif operation in {"add", "replace"}:
current[attr] = value
elif operation == "remove":
current.pop(attr, None)
else:
# Navigate deeper
if attr not in current:
current[attr] = {}
current = current[attr]
def _matches_filter(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
"""Check if an item matches the filter expression"""
if not filter_expr:
return True
filter_type = filter_expr.get("type", "comparison")
if filter_type == "comparison":
return self._matches_comparison(item, filter_expr)
elif filter_type == "logical":
return self._matches_logical(item, filter_expr)
return False
def _matches_comparison( # noqa PLR0912
self, item: dict[str, Any], filter_expr: dict[str, Any]
) -> bool:
"""Check if an item matches a comparison filter"""
attr = filter_expr["attribute"]
operator = filter_expr["operator"]
expected_value = filter_expr["value"]
if attr not in item:
return False
actual_value = item[attr]
if operator == "eq":
return actual_value == expected_value
elif operator == "ne":
return actual_value != expected_value
elif operator == "co":
return str(expected_value) in str(actual_value)
elif operator == "sw":
return str(actual_value).startswith(str(expected_value))
elif operator == "ew":
return str(actual_value).endswith(str(expected_value))
elif operator == "gt":
return actual_value > expected_value
elif operator == "lt":
return actual_value < expected_value
elif operator == "ge":
return actual_value >= expected_value
elif operator == "le":
return actual_value <= expected_value
elif operator == "pr":
return actual_value is not None
return False
def _matches_logical(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
"""Check if an item matches a logical filter expression"""
operator = filter_expr["operator"]
if operator == "and":
left_result = self._matches_filter(item, filter_expr["left"])
right_result = self._matches_filter(item, filter_expr["right"])
return left_result and right_result
elif operator == "or":
left_result = self._matches_filter(item, filter_expr["left"])
right_result = self._matches_filter(item, filter_expr["right"])
return left_result or right_result
elif operator == "not":
operand_result = self._matches_filter(item, filter_expr["operand"])
return not operand_result
return False

View File

@@ -1101,6 +1101,17 @@
"returned": "default",
"uniqueness": "none"
},
{
"name": "password",
"type": "string",
"multiValued": false,
"description": "The User's cleartext password. This attribute is intended to be used as a means to specify an initial\npassword when creating a new User or to reset an existing User's password.",
"required": false,
"caseExact": false,
"mutability": "writeOnly",
"returned": "never",
"uniqueness": "none"
},
{
"name": "emails",
"type": "complex",

View File

@@ -75,9 +75,7 @@ class TestSCIMGroups(APITestCase):
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
@@ -88,7 +86,6 @@ class TestSCIMGroups(APITestCase):
"""Test group create"""
user = create_test_user()
ext_id = generate_id()
name = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
@@ -98,7 +95,7 @@ class TestSCIMGroups(APITestCase):
),
data=dumps(
{
"displayName": name,
"displayName": generate_id(),
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
}
@@ -107,22 +104,12 @@ class TestSCIMGroups(APITestCase):
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
connection = SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).first()
self.assertIsNotNone(connection)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
connection.refresh_from_db()
self.assertEqual(
connection.attributes,
{
"displayName": name,
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
},
)
def test_group_create_members_empty(self):
"""Test group create"""
@@ -139,9 +126,7 @@ class TestSCIMGroups(APITestCase):
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
@@ -151,9 +136,7 @@ class TestSCIMGroups(APITestCase):
def test_group_create_duplicate(self):
"""Test group create (duplicate)"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(
source=self.source, group=group, external_id=uuid4()
)
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.post(
reverse(
@@ -182,9 +165,7 @@ class TestSCIMGroups(APITestCase):
def test_group_update(self):
"""Test group update"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(
source=self.source, group=group, external_id=uuid4()
)
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
@@ -224,49 +205,12 @@ class TestSCIMGroups(APITestCase):
},
)
def test_group_patch_modify(self):
"""Test group patch"""
group = Group.objects.create(name=generate_id())
connection = SCIMSourceGroup.objects.create(
source=self.source,
group=group,
external_id=uuid4(),
attributes={"displayName": group.name, "members": []},
)
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "Add",
"value": {"externalId": "d85051cb-0557-4aa1-98ca-51eabcee4d40"},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200, response.content)
connection = SCIMSourceGroup.objects.filter(id="d85051cb-0557-4aa1-98ca-51eabcee4d40")
self.assertIsNotNone(connection)
def test_group_patch_member_add(self):
def test_group_patch_add(self):
"""Test group patch"""
user = create_test_user()
other_user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(other_user)
connection = SCIMSourceGroup.objects.create(
source=self.source,
group=group,
external_id=uuid4(),
attributes={"displayName": group.name, "members": [{"value": str(other_user.uuid)}]},
)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
@@ -278,7 +222,7 @@ class TestSCIMGroups(APITestCase):
{
"op": "Add",
"path": "members",
"value": [{"value": str(user.uuid)}],
"value": {"value": str(user.uuid)},
}
]
}
@@ -286,33 +230,16 @@ class TestSCIMGroups(APITestCase):
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200, response.content)
self.assertEqual(response.status_code, second=200)
self.assertTrue(group.users.filter(pk=user.pk).exists())
self.assertTrue(group.users.filter(pk=other_user.pk).exists())
connection.refresh_from_db()
self.assertEqual(
connection.attributes,
{
"displayName": group.name,
"members": sorted(
[{"value": str(other_user.uuid)}, {"value": str(user.uuid)}],
key=lambda u: u["value"],
),
},
)
def test_group_patch_member_remove(self):
def test_group_patch_remove(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(user)
connection = SCIMSourceGroup.objects.create(
source=self.source,
group=group,
external_id=uuid4(),
attributes={"displayName": group.name, "members": []},
)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
@@ -324,7 +251,7 @@ class TestSCIMGroups(APITestCase):
{
"op": "remove",
"path": "members",
"value": [{"value": str(user.uuid)}],
"value": {"value": str(user.uuid)},
}
]
}
@@ -332,21 +259,13 @@ class TestSCIMGroups(APITestCase):
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200, response.content)
self.assertEqual(response.status_code, second=200)
self.assertFalse(group.users.filter(pk=user.pk).exists())
connection.refresh_from_db()
self.assertEqual(
connection.attributes,
{
"displayName": group.name,
"members": [],
},
)
def test_group_delete(self):
"""Test group delete"""
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, external_id=uuid4())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-groups",

View File

@@ -1,510 +0,0 @@
from unittest import TestCase
from authentik.sources.scim.constants import (
SCIM_URN_GROUP,
SCIM_URN_SCHEMA,
SCIM_URN_USER,
SCIM_URN_USER_ENTERPRISE,
)
from authentik.sources.scim.patch.lexer import SCIMPathLexer, Token, TokenType
class TestTokenType(TestCase):
"""Test TokenType enum"""
def test_token_type_values(self):
"""Test that all token types have correct values"""
self.assertEqual(TokenType.ATTRIBUTE.value, "ATTRIBUTE")
self.assertEqual(TokenType.DOT.value, "DOT")
self.assertEqual(TokenType.LBRACKET.value, "LBRACKET")
self.assertEqual(TokenType.RBRACKET.value, "RBRACKET")
self.assertEqual(TokenType.LPAREN.value, "LPAREN")
self.assertEqual(TokenType.RPAREN.value, "RPAREN")
self.assertEqual(TokenType.STRING.value, "STRING")
self.assertEqual(TokenType.NUMBER.value, "NUMBER")
self.assertEqual(TokenType.BOOLEAN.value, "BOOLEAN")
self.assertEqual(TokenType.NULL.value, "NULL")
self.assertEqual(TokenType.OPERATOR.value, "OPERATOR")
self.assertEqual(TokenType.AND.value, "AND")
self.assertEqual(TokenType.OR.value, "OR")
self.assertEqual(TokenType.NOT.value, "NOT")
self.assertEqual(TokenType.EOF.value, "EOF")
class TestToken(TestCase):
"""Test Token dataclass"""
def test_token_creation(self):
"""Test token creation with all parameters"""
token = Token(TokenType.ATTRIBUTE, "userName", 5)
self.assertEqual(token.type, TokenType.ATTRIBUTE)
self.assertEqual(token.value, "userName")
self.assertEqual(token.position, 5)
def test_token_creation_default_position(self):
"""Test token creation with default position"""
token = Token(TokenType.DOT, ".")
self.assertEqual(token.type, TokenType.DOT)
self.assertEqual(token.value, ".")
self.assertEqual(token.position, 0)
class TestSCIMPathLexer(TestCase):
"""Test SCIMPathLexer class"""
def setUp(self):
"""Set up test fixtures"""
self.simple_lexer = SCIMPathLexer("userName")
def test_init(self):
"""Test lexer initialization"""
lexer = SCIMPathLexer("test")
self.assertEqual(lexer.text, "test")
self.assertEqual(lexer.pos, 0)
self.assertEqual(lexer.current_char, "t")
self.assertIn(SCIM_URN_SCHEMA, lexer.schema_urns)
self.assertIn(SCIM_URN_GROUP, lexer.schema_urns)
self.assertIn(SCIM_URN_USER, lexer.schema_urns)
self.assertIn(SCIM_URN_USER_ENTERPRISE, lexer.schema_urns)
self.assertEqual(
lexer.OPERATORS, ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
)
def test_init_empty_string(self):
"""Test lexer initialization with empty string"""
lexer = SCIMPathLexer("")
self.assertEqual(lexer.text, "")
self.assertEqual(lexer.pos, 0)
self.assertIsNone(lexer.current_char)
def test_advance(self):
"""Test advance method"""
lexer = SCIMPathLexer("abc")
self.assertEqual(lexer.current_char, "a")
lexer.advance()
self.assertEqual(lexer.pos, 1)
self.assertEqual(lexer.current_char, "b")
lexer.advance()
self.assertEqual(lexer.pos, 2)
self.assertEqual(lexer.current_char, "c")
lexer.advance()
self.assertEqual(lexer.pos, 3)
self.assertIsNone(lexer.current_char)
def test_skip_whitespace(self):
"""Test skip_whitespace method"""
lexer = SCIMPathLexer(" \t\n abc")
lexer.skip_whitespace()
self.assertEqual(lexer.current_char, "a")
def test_skip_whitespace_only_whitespace(self):
"""Test skip_whitespace with only whitespace"""
lexer = SCIMPathLexer(" \t\n ")
lexer.skip_whitespace()
self.assertIsNone(lexer.current_char)
def test_skip_whitespace_no_whitespace(self):
"""Test skip_whitespace with no leading whitespace"""
lexer = SCIMPathLexer("abc")
original_pos = lexer.pos
lexer.skip_whitespace()
self.assertEqual(lexer.pos, original_pos)
self.assertEqual(lexer.current_char, "a")
def test_read_string_double_quotes(self):
"""Test reading double-quoted string"""
lexer = SCIMPathLexer('"hello world"')
result = lexer.read_string('"')
self.assertEqual(result, "hello world")
self.assertIsNone(lexer.current_char) # Should be at end
def test_read_string_single_quotes(self):
"""Test reading single-quoted string"""
lexer = SCIMPathLexer("'hello world'")
result = lexer.read_string("'")
self.assertEqual(result, "hello world")
self.assertIsNone(lexer.current_char)
def test_read_string_with_escapes(self):
"""Test reading string with escape characters"""
lexer = SCIMPathLexer('"hello \\"world\\""')
result = lexer.read_string('"')
self.assertEqual(result, 'hello "world"')
def test_read_string_with_backslash_at_end(self):
"""Test reading string with backslash at end"""
lexer = SCIMPathLexer('"hello\\"')
result = lexer.read_string('"')
self.assertEqual(result, 'hello"')
def test_read_string_unclosed(self):
"""Test reading unclosed string"""
lexer = SCIMPathLexer('"hello world')
result = lexer.read_string('"')
self.assertEqual(result, "hello world")
self.assertIsNone(lexer.current_char)
def test_read_string_empty(self):
"""Test reading empty string"""
lexer = SCIMPathLexer('""')
result = lexer.read_string('"')
self.assertEqual(result, "")
def test_read_number_integer(self):
"""Test reading integer number"""
lexer = SCIMPathLexer("123")
result = lexer.read_number()
self.assertEqual(result, "123")
self.assertIsNone(lexer.current_char)
def test_read_number_float(self):
"""Test reading float number"""
lexer = SCIMPathLexer("123.456")
result = lexer.read_number()
self.assertEqual(result, "123.456")
self.assertIsNone(lexer.current_char)
def test_read_number_with_multiple_dots(self):
"""Test reading number with multiple dots (invalid but handled)"""
lexer = SCIMPathLexer("123.456.789")
result = lexer.read_number()
self.assertEqual(result, "123.456.789")
self.assertIsNone(lexer.current_char)
def test_read_number_starting_with_dot(self):
"""Test reading number starting with dot"""
lexer = SCIMPathLexer(".123")
result = lexer.read_number()
self.assertEqual(result, ".123")
def test_read_identifier_simple(self):
"""Test reading simple identifier"""
lexer = SCIMPathLexer("userName")
result = lexer.read_identifier()
self.assertEqual(result, "userName")
self.assertIsNone(lexer.current_char)
def test_read_identifier_with_underscore(self):
"""Test reading identifier with underscore"""
lexer = SCIMPathLexer("user_name")
result = lexer.read_identifier()
self.assertEqual(result, "user_name")
def test_read_identifier_with_hyphen(self):
"""Test reading identifier with hyphen"""
lexer = SCIMPathLexer("user-name")
result = lexer.read_identifier()
self.assertEqual(result, "user-name")
def test_read_identifier_with_colon(self):
"""Test reading identifier with colon (URN format)"""
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
result = lexer.read_identifier()
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
def test_read_identifier_schema_urn(self):
"""Test reading schema URN identifier"""
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
result = lexer.read_identifier()
self.assertEqual(result, SCIM_URN_USER)
self.assertEqual(lexer.current_char, ".") # Should stop at dot and set current_char to dot
def test_read_identifier_with_version_number(self):
"""Test reading identifier with version number (dots followed by digits)"""
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
result = lexer.read_identifier()
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
def test_read_identifier_partial_urn_match(self):
"""Test reading identifier that partially matches URN"""
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:CustomUser")
result = lexer.read_identifier()
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:CustomUser")
# Test get_next_token method
def test_get_next_token_dot(self):
"""Test tokenizing dot"""
lexer = SCIMPathLexer(".")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.DOT)
self.assertEqual(token.value, ".")
def test_get_next_token_lbracket(self):
"""Test tokenizing left bracket"""
lexer = SCIMPathLexer("[")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.LBRACKET)
self.assertEqual(token.value, "[")
def test_get_next_token_rbracket(self):
"""Test tokenizing right bracket"""
lexer = SCIMPathLexer("]")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.RBRACKET)
self.assertEqual(token.value, "]")
def test_get_next_token_lparen(self):
"""Test tokenizing left parenthesis"""
lexer = SCIMPathLexer("(")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.LPAREN)
self.assertEqual(token.value, "(")
def test_get_next_token_rparen(self):
"""Test tokenizing right parenthesis"""
lexer = SCIMPathLexer(")")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.RPAREN)
self.assertEqual(token.value, ")")
def test_get_next_token_string_double_quotes(self):
"""Test tokenizing double-quoted string"""
lexer = SCIMPathLexer('"test string"')
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.STRING)
self.assertEqual(token.value, "test string")
def test_get_next_token_string_single_quotes(self):
"""Test tokenizing single-quoted string"""
lexer = SCIMPathLexer("'test string'")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.STRING)
self.assertEqual(token.value, "test string")
def test_get_next_token_number_integer(self):
"""Test tokenizing integer"""
lexer = SCIMPathLexer("123")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.NUMBER)
self.assertEqual(token.value, "123")
def test_get_next_token_number_float(self):
"""Test tokenizing float"""
lexer = SCIMPathLexer("123.45")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.NUMBER)
self.assertEqual(token.value, "123.45")
def test_get_next_token_boolean_true(self):
"""Test tokenizing boolean true"""
lexer = SCIMPathLexer("true")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.BOOLEAN)
self.assertTrue(token.value)
def test_get_next_token_boolean_false(self):
"""Test tokenizing boolean false"""
lexer = SCIMPathLexer("false")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.BOOLEAN)
self.assertFalse(token.value)
def test_get_next_token_boolean_case_insensitive(self):
"""Test tokenizing boolean with different cases"""
for value in ["TRUE", "True", "FALSE", "False"]:
with self.subTest(value=value):
lexer = SCIMPathLexer(value)
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.BOOLEAN)
def test_get_next_token_null(self):
"""Test tokenizing null"""
lexer = SCIMPathLexer("null")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.NULL)
self.assertIsNone(token.value)
def test_get_next_token_null_case_insensitive(self):
"""Test tokenizing null with different cases"""
for value in ["NULL", "Null"]:
with self.subTest(value=value):
lexer = SCIMPathLexer(value)
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.NULL)
def test_get_next_token_and(self):
"""Test tokenizing AND operator"""
lexer = SCIMPathLexer("and")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.AND)
self.assertEqual(token.value, "and")
def test_get_next_token_or(self):
"""Test tokenizing OR operator"""
lexer = SCIMPathLexer("or")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.OR)
self.assertEqual(token.value, "or")
def test_get_next_token_not(self):
"""Test tokenizing NOT operator"""
lexer = SCIMPathLexer("not")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.NOT)
self.assertEqual(token.value, "not")
def test_get_next_token_operators(self):
"""Test tokenizing all comparison operators"""
operators = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
for op in operators:
with self.subTest(operator=op):
lexer = SCIMPathLexer(op)
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.OPERATOR)
self.assertEqual(token.value, op)
def test_get_next_token_operators_case_insensitive(self):
"""Test tokenizing operators with different cases"""
for op in ["EQ", "Eq", "NE", "Ne"]:
with self.subTest(operator=op):
lexer = SCIMPathLexer(op)
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.OPERATOR)
self.assertEqual(token.value, op.lower())
def test_get_next_token_attribute(self):
"""Test tokenizing attribute name"""
lexer = SCIMPathLexer("userName")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.ATTRIBUTE)
self.assertEqual(token.value, "userName")
def test_get_next_token_attribute_with_underscore(self):
"""Test tokenizing attribute name with underscore"""
lexer = SCIMPathLexer("_userName")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.ATTRIBUTE)
self.assertEqual(token.value, "_userName")
def test_get_next_token_eof(self):
"""Test tokenizing end of file"""
lexer = SCIMPathLexer("")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.EOF)
self.assertEqual(token.value, "")
def test_get_next_token_with_whitespace(self):
"""Test tokenizing with leading whitespace"""
lexer = SCIMPathLexer(" userName")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.ATTRIBUTE)
self.assertEqual(token.value, "userName")
def test_get_next_token_skip_unknown_characters(self):
"""Test that unknown characters are skipped"""
lexer = SCIMPathLexer("@#$userName")
token = lexer.get_next_token()
self.assertEqual(token.type, TokenType.ATTRIBUTE)
self.assertEqual(token.value, "userName")
def test_get_next_token_multiple_tokens(self):
"""Test tokenizing multiple tokens in sequence"""
lexer = SCIMPathLexer("userName.givenName")
token1 = lexer.get_next_token()
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
self.assertEqual(token1.value, "userName")
token2 = lexer.get_next_token()
self.assertEqual(token2.type, TokenType.DOT)
self.assertEqual(token2.value, ".")
token3 = lexer.get_next_token()
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
self.assertEqual(token3.value, "givenName")
token4 = lexer.get_next_token()
self.assertEqual(token4.type, TokenType.EOF)
def test_get_next_token_complex_filter(self):
"""Test tokenizing complex filter expression"""
lexer = SCIMPathLexer('emails[type eq "work" and primary eq true]')
tokens = []
while True:
token = lexer.get_next_token()
tokens.append(token)
if token.type == TokenType.EOF:
break
expected_types = [
TokenType.ATTRIBUTE, # emails
TokenType.LBRACKET, # [
TokenType.ATTRIBUTE, # type
TokenType.OPERATOR, # eq
TokenType.STRING, # "work"
TokenType.AND, # and
TokenType.ATTRIBUTE, # primary
TokenType.OPERATOR, # eq
TokenType.BOOLEAN, # true
TokenType.RBRACKET, # ]
TokenType.EOF,
]
self.assertEqual(len(tokens), len(expected_types))
for token, expected_type in zip(tokens, expected_types, strict=False):
self.assertEqual(token.type, expected_type)
def test_get_next_token_urn_attribute(self):
"""Test tokenizing URN-based attribute"""
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
token1 = lexer.get_next_token()
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
self.assertEqual(token1.value, SCIM_URN_USER)
token2 = lexer.get_next_token()
self.assertEqual(token2.type, TokenType.DOT)
token3 = lexer.get_next_token()
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
self.assertEqual(token3.value, "userName")
def test_get_next_token_enterprise_urn(self):
"""Test tokenizing enterprise URN"""
lexer = SCIMPathLexer(f"{SCIM_URN_USER_ENTERPRISE}.manager")
token1 = lexer.get_next_token()
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
self.assertEqual(token1.value, SCIM_URN_USER_ENTERPRISE)
token2 = lexer.get_next_token()
self.assertEqual(token2.type, TokenType.DOT)
def test_lexer_state_after_eof(self):
"""Test lexer state after reaching EOF"""
lexer = SCIMPathLexer("a")
# Get first token
token1 = lexer.get_next_token()
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
# Get EOF token
token2 = lexer.get_next_token()
self.assertEqual(token2.type, TokenType.EOF)
# Should continue returning EOF
token3 = lexer.get_next_token()
self.assertEqual(token3.type, TokenType.EOF)
def test_read_identifier_edge_cases(self):
"""Test read_identifier with edge cases"""
# Test identifier ending with colon
lexer = SCIMPathLexer("test:")
result = lexer.read_identifier()
self.assertEqual(result, "test:")
# Test identifier with numbers
lexer = SCIMPathLexer("test123")
result = lexer.read_identifier()
self.assertEqual(result, "test123")
def test_complex_urn_parsing(self):
"""Test parsing complex URN with version numbers"""
urn = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
lexer = SCIMPathLexer(urn)
result = lexer.read_identifier()
self.assertEqual(result, urn)

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,6 @@ from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
from authentik.sources.scim.models import SCIMSource, SCIMSourcePropertyMapping, SCIMSourceUser
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
@@ -82,9 +81,7 @@ class TestSCIMUsers(APITestCase):
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(
SCIMSourceUser.objects.filter(source=self.source, external_id=ext_id).exists()
)
self.assertTrue(SCIMSourceUser.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
@@ -177,16 +174,14 @@ class TestSCIMUsers(APITestCase):
)
self.assertEqual(response.status_code, 201)
self.assertEqual(
SCIMSourceUser.objects.get(source=self.source, external_id=ext_id).user.attributes[
"phone"
],
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
"0123456789",
)
def test_user_update(self):
"""Test user update"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
@@ -214,51 +209,10 @@ class TestSCIMUsers(APITestCase):
)
self.assertEqual(response.status_code, 200)
def test_user_update_patch(self):
"""Test user update (patch)"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(
source=self.source,
user=user,
external_id=uuid4(),
attributes={
"userName": generate_id(),
},
)
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
data=dumps(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "Add",
"path": f"{SCIM_URN_USER_ENTERPRISE}:manager",
"value": "86b2ed3e-30cd-4881-bb58-c4e910821339",
}
],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
existing.refresh_from_db()
self.assertEqual(
existing.attributes[SCIM_URN_USER_ENTERPRISE],
{"manager": {"value": "86b2ed3e-30cd-4881-bb58-c4e910821339"}},
)
def test_user_delete(self):
"""Test user delete"""
user = create_test_user()
SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-users",

View File

@@ -1,488 +0,0 @@
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_user
from authentik.lib.generators import generate_id
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
from authentik.sources.scim.models import SCIMSource, SCIMSourceUser
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
class TestSCIMUsersPatch(APITestCase):
"""Test SCIM User Patch"""
def test_add(self):
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{"op": "Add", "path": "name.givenName", "value": "aqwer"},
{"op": "Add", "path": "name.familyName", "value": "qwerqqqq"},
{"op": "Add", "path": "name.formatted", "value": "aqwer qwerqqqq"},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"name": {
"givenName": "aqwer",
"familyName": "qwerqqqq",
"formatted": "aqwer qwerqqqq",
},
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
def test_add_no_path(self):
"""Test add patch with no path set"""
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{"op": "Add", "value": {"externalId": "aqwer"}},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "aqwer",
"displayName": "Test MS",
},
)
def test_replace(self):
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{"op": "Replace", "path": "name", "value": {"givenName": "aqwer"}},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"name": {
"givenName": "aqwer",
},
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
def test_replace_no_path(self):
"""Test value replace with no path"""
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{"op": "Replace", "value": {"externalId": "aqwer"}},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "aqwer",
"displayName": "Test MS",
},
)
def test_remove(self):
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{"op": "Remove", "path": "name", "value": {"givenName": "aqwer"}},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"name": {
"givenName": "aqwer",
},
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
def test_large(self):
"""Large amount of patch operations"""
req = {
"Operations": [
{
"op": "replace",
"path": "emails[primary eq true].value",
"value": "dandre_kling@wintheiser.info",
},
{
"op": "replace",
"path": "phoneNumbers[primary eq true].value",
"value": "72-634-1548",
},
{
"op": "replace",
"path": "phoneNumbers[primary eq true].display",
"value": "72-634-1548",
},
{"op": "replace", "path": "ims[primary eq true].value", "value": "GXSGJKWGHVVS"},
{"op": "replace", "path": "ims[primary eq true].display", "value": "IMCHDKUQIPYB"},
{
"op": "replace",
"path": "photos[primary eq true].display",
"value": "TWAWLHHSUNIV",
},
{
"op": "replace",
"path": "addresses[primary eq true].formatted",
"value": "TMINZQAJQDCL",
},
{
"op": "replace",
"path": "addresses[primary eq true].streetAddress",
"value": "081 Wisoky Key",
},
{
"op": "replace",
"path": "addresses[primary eq true].locality",
"value": "DPFASBZRPMDP",
},
{
"op": "replace",
"path": "addresses[primary eq true].region",
"value": "WHSTJSPIPTCF",
},
{
"op": "replace",
"path": "addresses[primary eq true].postalCode",
"value": "ko28 1qa",
},
{"op": "replace", "path": "addresses[primary eq true].country", "value": "Taiwan"},
{
"op": "replace",
"path": "entitlements[primary eq true].value",
"value": "NGBJMUYZVVBX",
},
{"op": "replace", "path": "roles[primary eq true].value", "value": "XEELVFMMWCVM"},
{
"op": "replace",
"path": "x509Certificates[primary eq true].value",
"value": "UYISMEDOXUZY",
},
{
"op": "replace",
"value": {
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
"name.formatted": "Dell",
"name.familyName": "Gay",
"name.givenName": "Kyler",
"name.middleName": "Hannah",
"name.honorificPrefix": "Cassie",
"name.honorificSuffix": "Yolanda",
"displayName": "DPRLIJSFQMTL",
"nickName": "BKSPMIRMFBTI",
"title": "NBZCOAXVYJUY",
"userType": "ZGJMYZRUORZE",
"preferredLanguage": "as-IN",
"locale": "JLOJHLPWZODG",
"timezone": "America/Argentina/Rio_Gallegos",
"active": True,
f"{SCIM_URN_USER_ENTERPRISE}:employeeNumber": "PDFWRRZBQOHB",
f"{SCIM_URN_USER_ENTERPRISE}:costCenter": "HACMZWSEDOTQ",
f"{SCIM_URN_USER_ENTERPRISE}:organization": "LXVHJUOLNCLS",
f"{SCIM_URN_USER_ENTERPRISE}:division": "JASVTPKPBPMG",
f"{SCIM_URN_USER_ENTERPRISE}:department": "GMSBFLMNPABY",
},
},
],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"active": True,
"addresses": [
{
"primary": "true",
"formatted": "BLJMCNXHYLZK",
"streetAddress": "7801 Jacobs Fork",
"locality": "HZJBJWFAKXDD",
"region": "GJXCXPMIIKWK",
"postalCode": "pv82 8ua",
"country": "India",
}
],
"displayName": "KEFXCHKHAFOT",
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
"externalId": "448d2786-7bf6-4e03-a4ef-64cbaf162fa7",
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
"locale": "PJNYJHWJILTI",
"name": {
"formatted": "Ladarius",
"familyName": "Manley",
"givenName": "Mazie",
"middleName": "Vernon",
"honorificPrefix": "Melyssa",
"honorificSuffix": "Demarcus",
},
"nickName": "HTPKOXMWZKHL",
"phoneNumbers": [
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
],
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
"preferredLanguage": "wae",
"profileUrl": "HPSEOIPXMGOH",
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"timezone": "America/Indiana/Petersburg",
"title": "EJWFXLHNHMCD",
SCIM_URN_USER_ENTERPRISE: {
"employeeNumber": "XHDMEJUURJNR",
"costCenter": "RXUYBXOTRCZH",
"organization": "CEXWXMBRYAHN",
"division": "XMPFMDCLRKCW",
"department": "BKMNJVMCJUYS",
"manager": "PNGSGXLYVWMV",
},
"userName": "imelda.auer@kshlerin.co.uk",
"userType": "PZFXORVSUAPU",
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"active": True,
"addresses": [
{
"primary": "true",
"formatted": "BLJMCNXHYLZK",
"streetAddress": "7801 Jacobs Fork",
"locality": "HZJBJWFAKXDD",
"region": "GJXCXPMIIKWK",
"postalCode": "pv82 8ua",
"country": "India",
}
],
"displayName": "DPRLIJSFQMTL",
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
"locale": "JLOJHLPWZODG",
"name": {
"formatted": "Dell",
"familyName": "Gay",
"givenName": "Kyler",
"middleName": "Hannah",
"honorificPrefix": "Cassie",
"honorificSuffix": "Yolanda",
},
"nickName": "BKSPMIRMFBTI",
"phoneNumbers": [
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
],
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
"preferredLanguage": "as-IN",
"profileUrl": "HPSEOIPXMGOH",
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"timezone": "America/Argentina/Rio_Gallegos",
"title": "NBZCOAXVYJUY",
SCIM_URN_USER_ENTERPRISE: {
"employeeNumber": "PDFWRRZBQOHB",
"costCenter": "HACMZWSEDOTQ",
"organization": "LXVHJUOLNCLS",
"division": "JASVTPKPBPMG",
"department": "GMSBFLMNPABY",
"manager": "PNGSGXLYVWMV",
},
"userName": "imelda.auer@kshlerin.co.uk",
"userType": "ZGJMYZRUORZE",
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
},
)
def test_schema_urn_manager(self):
req = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "Add",
"value": {
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager": "foo"
},
},
],
}
user = create_test_user()
source = SCIMSource.objects.create(slug=generate_id())
connection = SCIMSourceUser.objects.create(
user=user,
id=generate_id(),
source=source,
attributes={
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
},
)
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
self.assertEqual(
updated,
{
"meta": {"resourceType": "User"},
"active": True,
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
SCIM_URN_USER_ENTERPRISE,
],
"userName": "test@t.goauthentik.io",
"externalId": "test",
"displayName": "Test MS",
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
"manager": {"value": "foo"}
},
},
)

View File

@@ -1,7 +1,6 @@
"""SCIM Utils"""
from typing import Any
from uuid import UUID
from django.conf import settings
from django.core.paginator import Page, Paginator
@@ -22,7 +21,6 @@ from authentik.core.sources.mapper import SourceMapper
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.sources.scim.models import SCIMSource
from authentik.sources.scim.views.v2.auth import SCIMTokenAuth
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
SCIM_CONTENT_TYPE = "application/scim+json"
@@ -56,13 +54,6 @@ class SCIMView(APIView):
def get_authenticators(self):
return [SCIMTokenAuth(self)]
def remove_excluded_attributes(self, data: dict):
"""Remove attributes specified in excludedAttributes"""
excluded: str = self.request.query_params.get("excludedAttributes", "")
for key in excluded.split(","):
data.pop(key.strip(), None)
return data
def filter_parse(self, request: Request):
"""Parse the path of a Patch Operation"""
path = request.query_params.get("filter")
@@ -112,12 +103,6 @@ class SCIMObjectView(SCIMView):
# a source attribute before
self.mapper = SourceMapper(self.source)
self.manager = self.mapper.get_manager(self.model, ["data"])
for key, value in kwargs.items():
if key.endswith("_id"):
try:
UUID(value)
except ValueError:
raise SCIMNotFoundError("Invalid ID") from None
def build_object_properties(self, data: dict[str, Any]) -> dict[str, Any | dict[str, Any]]:
return self.mapper.build_object_properties(

View File

@@ -17,7 +17,6 @@ from authentik.core.models import Group, User
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
from authentik.sources.scim.models import SCIMSourceGroup
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import (
SCIMConflictError,
@@ -36,12 +35,11 @@ class GroupsView(SCIMObjectView):
payload = SCIMGroupModel(
schemas=[SCIM_GROUP_SCHEMA],
id=str(scim_group.group.pk),
externalId=scim_group.external_id,
externalId=scim_group.id,
displayName=scim_group.group.name,
members=[],
meta={
"resourceType": "Group",
"lastModified": scim_group.last_update,
"location": self.request.build_absolute_uri(
reverse(
"authentik_sources_scim:v2-groups",
@@ -56,11 +54,7 @@ class GroupsView(SCIMObjectView):
for member in scim_group.group.users.order_by("pk"):
member: User
payload.members.append(GroupMember(value=str(member.uuid)))
final_payload = payload.model_dump(mode="json", exclude_unset=True)
final_payload.update(scim_group.attributes)
return self.remove_excluded_attributes(
SCIMGroupModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
)
return payload.model_dump(mode="json", exclude_unset=True)
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
"""List Group handler"""
@@ -87,7 +81,7 @@ class GroupsView(SCIMObjectView):
)
@atomic
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict, apply_members=True):
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
"""Partial update a group"""
properties = self.build_object_properties(data)
@@ -100,7 +94,7 @@ class GroupsView(SCIMObjectView):
group.update_attributes(properties)
if "members" in data and apply_members:
if "members" in data:
query = Q()
for _member in data.get("members", []):
try:
@@ -111,18 +105,14 @@ class GroupsView(SCIMObjectView):
query |= Q(uuid=member.value)
if query:
group.users.set(User.objects.filter(query))
data["members"] = self._convert_members(group)
if not connection:
connection, _ = SCIMSourceGroup.objects.update_or_create(
external_id=data.get("externalId") or str(uuid4()),
connection, _ = SCIMSourceGroup.objects.get_or_create(
source=self.source,
group=group,
defaults={
"attributes": data,
},
attributes=data,
id=data.get("externalId") or str(uuid4()),
)
else:
connection.external_id = data.get("externalId", connection.external_id)
connection.attributes = data
connection.save()
return connection
@@ -149,12 +139,6 @@ class GroupsView(SCIMObjectView):
connection = self.update_group(connection, request.data)
return Response(self.group_to_scim(connection), status=200)
def _convert_members(self, group: Group):
users = []
for user in group.users.all().order_by("uuid"):
users.append({"value": str(user.uuid)})
return sorted(users, key=lambda u: u["value"])
@atomic
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
"""Patch group handler"""
@@ -187,13 +171,6 @@ class GroupsView(SCIMObjectView):
query |= Q(uuid=member["value"])
if query:
connection.group.users.remove(*User.objects.filter(query))
patcher = SCIMPatchProcessor()
patched_data = patcher.apply_patches(
connection.attributes, request.data.get("Operations", [])
)
patched_data["members"] = self._convert_members(connection.group)
if patched_data != connection.attributes:
self.update_group(connection, patched_data, apply_members=False)
return Response(self.group_to_scim(connection), status=200)
@atomic

View File

@@ -33,7 +33,9 @@ class ServiceProviderConfigView(SCIMView):
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
"authenticationSchemes": auth_schemas,
"patch": {"supported": True},
# We only support patch for groups currently, so don't broadly advertise it.
# Implementations that require Group patch will use it regardless of this flag.
"patch": {"supported": False},
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
"filter": {
"supported": True,

View File

@@ -15,7 +15,6 @@ from authentik.core.models import User
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import User as SCIMUserModel
from authentik.sources.scim.models import SCIMSourceUser
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
@@ -30,7 +29,7 @@ class UsersView(SCIMObjectView):
payload = SCIMUserModel(
schemas=[SCIM_USER_SCHEMA],
id=str(scim_user.user.uuid),
externalId=scim_user.external_id,
externalId=scim_user.id,
userName=scim_user.user.username,
name=Name(
formatted=scim_user.user.name,
@@ -45,7 +44,8 @@ class UsersView(SCIMObjectView):
meta={
"resourceType": "User",
"created": scim_user.user.date_joined,
"lastModified": scim_user.last_update,
# TODO: use events to find last edit?
"lastModified": scim_user.user.date_joined,
"location": self.request.build_absolute_uri(
reverse(
"authentik_sources_scim:v2-users",
@@ -59,9 +59,7 @@ class UsersView(SCIMObjectView):
)
final_payload = payload.model_dump(mode="json", exclude_unset=True)
final_payload.update(scim_user.attributes)
return self.remove_excluded_attributes(
SCIMUserModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
)
return final_payload
def get(self, request: Request, user_id: str | None = None, **kwargs) -> Response:
"""List User handler"""
@@ -103,16 +101,13 @@ class UsersView(SCIMObjectView):
user.update_attributes(properties)
if not connection:
connection, _ = SCIMSourceUser.objects.update_or_create(
external_id=data.get("externalId") or str(uuid4()),
connection, _ = SCIMSourceUser.objects.get_or_create(
source=self.source,
user=user,
defaults={
"attributes": data,
},
attributes=data,
id=data.get("externalId") or str(uuid4()),
)
else:
connection.external_id = data.get("externalId", connection.external_id)
connection.attributes = data
connection.save()
return connection
@@ -132,18 +127,6 @@ class UsersView(SCIMObjectView):
connection = self.update_user(None, request.data)
return Response(self.user_to_scim(connection), status=201)
def patch(self, request: Request, user_id: str, **kwargs):
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
patcher = SCIMPatchProcessor()
patched_data = patcher.apply_patches(
connection.attributes, request.data.get("Operations", [])
)
if patched_data != connection.attributes:
self.update_user(connection, patched_data)
return Response(self.user_to_scim(connection), status=200)
def put(self, request: Request, user_id: str, **kwargs) -> Response:
"""Update user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()

View File

@@ -13,6 +13,7 @@ from authentik.flows.exceptions import StageInvalidException
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.lib.config import CONFIG
from authentik.lib.models import SerializerModel
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.time import timedelta_string_validator
from authentik.stages.authenticator.models import SideChannelDevice
from authentik.stages.email.utils import TemplateEmailMessage
@@ -159,8 +160,9 @@ class EmailDevice(SerializerModel, SideChannelDevice):
Event.new(
EventAction.CONFIGURATION_ERROR,
message=_("Exception occurred while rendering E-mail template"),
error=exception_to_string(exc),
template=stage.template,
).with_exception(exc).from_http(self.request)
).from_http(self.request)
raise StageInvalidException from exc
def __str__(self):

View File

@@ -17,6 +17,7 @@ from authentik.flows.challenge import (
from authentik.flows.exceptions import StageInvalidException
from authentik.flows.stage import ChallengeStageView
from authentik.lib.utils.email import mask_email
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.time import timedelta_from_string
from authentik.stages.authenticator_email.models import (
AuthenticatorEmailStage,
@@ -99,8 +100,9 @@ class AuthenticatorEmailStageView(ChallengeStageView):
Event.new(
EventAction.CONFIGURATION_ERROR,
message=_("Exception occurred while rendering E-mail template"),
error=exception_to_string(exc),
template=stage.template,
).with_exception(exc).from_http(self.request)
).from_http(self.request)
raise StageInvalidException from exc
def _has_email(self) -> str | None:

View File

@@ -4,7 +4,7 @@ from hashlib import sha256
from django.contrib.auth import get_user_model
from django.db import models
from django.http import HttpRequest, HttpResponseBadRequest
from django.http import HttpResponseBadRequest
from django.utils.translation import gettext_lazy as _
from django.views import View
from requests.exceptions import RequestException
@@ -19,6 +19,7 @@ from authentik.events.models import Event, EventAction, NotificationWebhookMappi
from authentik.events.utils import sanitize_item
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.lib.models import SerializerModel
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.http import get_http_session
from authentik.stages.authenticator.models import SideChannelDevice
@@ -68,44 +69,32 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
help_text=_("Optionally modify the payload being sent to custom providers."),
)
def send(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send(self, token: str, device: "SMSDevice"):
"""Send message via selected provider"""
if self.provider == SMSProviders.TWILIO:
return self.send_twilio(request, token, device)
return self.send_twilio(token, device)
if self.provider == SMSProviders.GENERIC:
return self.send_generic(request, token, device)
return self.send_generic(token, device)
raise ValueError(f"invalid provider {self.provider}")
def get_message(self, token: str) -> str:
"""Get SMS message"""
return _("Use this code to authenticate in authentik: {token}".format_map({"token": token}))
def send_twilio(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send_twilio(self, token: str, device: "SMSDevice"):
"""send sms via twilio provider"""
client = Client(self.account_sid, self.auth)
message_body = str(self.get_message(token))
if self.mapping:
payload = sanitize_item(
self.mapping.evaluate(
user=device.user,
request=request,
device=device,
token=token,
stage=self,
)
)
message_body = payload.get("message", message_body)
try:
message = client.messages.create(
to=device.phone_number, from_=self.from_number, body=message_body
to=device.phone_number, from_=self.from_number, body=str(self.get_message(token))
)
LOGGER.debug("Sent SMS", to=device, message=message.sid)
except TwilioRestException as exc:
LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg)
raise ValidationError(exc.msg) from None
def send_generic(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send_generic(self, token: str, device: "SMSDevice"):
"""Send SMS via outside API"""
payload = {
"From": self.from_number,
@@ -118,7 +107,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
payload = sanitize_item(
self.mapping.evaluate(
user=device.user,
request=request,
request=None,
device=device,
token=token,
stage=self,
@@ -153,9 +142,10 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Error sending SMS",
exc=exception_to_string(exc),
status_code=response.status_code,
body=response.text,
).with_exception(exc).set_user(device.user).save()
).set_user(device.user).save()
if response.status_code >= HttpResponseBadRequest.status_code:
raise ValidationError(response.text) from None
raise

View File

@@ -71,7 +71,7 @@ class AuthenticatorSMSStageView(ChallengeStageView):
raise ValidationError(_("Invalid phone number"))
# No code yet, but we have a phone number, so send a verification message
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
stage.send(self.request, device.token, device)
stage.send(device.token, device)
def _has_phone_number(self) -> str | None:
context = self.executor.plan.context

View File

@@ -9,7 +9,7 @@ from django.http.response import Http404
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as __
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField, ChoiceField, DateTimeField
from rest_framework.fields import CharField, DateTimeField
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger
from webauthn import options_to_json
@@ -18,7 +18,7 @@ from webauthn.authentication.verify_authentication_response import verify_authen
from webauthn.helpers import parse_authentication_credential_json
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
from webauthn.helpers.exceptions import InvalidAuthenticationResponse, InvalidJSONStructure
from webauthn.helpers.structs import PublicKeyCredentialType, UserVerificationRequirement
from webauthn.helpers.structs import UserVerificationRequirement
from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.core.models import Application, User
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
class DeviceChallenge(PassiveSerializer):
"""Single device challenge"""
device_class = ChoiceField(choices=DeviceClasses.choices)
device_class = CharField()
device_uid = CharField()
challenge = JSONDictField()
last_used = DateTimeField(allow_null=True)
@@ -124,7 +124,7 @@ def select_challenge(request: HttpRequest, device: Device):
def select_challenge_sms(request: HttpRequest, device: SMSDevice):
"""Send SMS"""
device.generate_token()
device.stage.send(request, device.token, device)
device.stage.send(device.token, device)
def select_challenge_email(request: HttpRequest, device: EmailDevice):
@@ -157,12 +157,6 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
request = stage_view.request
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
if "MinuteMaid" in request.META.get("HTTP_USER_AGENT", ""):
# Workaround for Android sign-in, when signing into Google Workspace on android while
# adding the account to the system (not in Chrome), for some reason `type` is not set
# so in that case we fall back to `public-key`
# since that's the only option we support anyways
data.setdefault("type", PublicKeyCredentialType.PUBLIC_KEY)
try:
credential = parse_authentication_credential_json(data)
except InvalidJSONStructure as exc:

View File

@@ -173,7 +173,6 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
{
"auth_method": "auth_mfa",
"auth_method_args": {
"known_device": False,
"mfa_devices": [
{
"app": "authentik_stages_authenticator_duo",
@@ -181,7 +180,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
"name": "",
"pk": duo_device.pk,
}
],
]
},
"http_request": {
"args": {},

View File

@@ -153,13 +153,13 @@ class AuthenticatorValidateStageTests(FlowTestCase):
plan.append_stage(stage)
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
"device_class": DeviceClasses.STATIC,
"device_class": "static",
"device_uid": "1",
"challenge": {},
"last_used": now(),
},
{
"device_class": DeviceClasses.TOTP,
"device_class": "totp",
"device_uid": "2",
"challenge": {},
"last_used": now(),
@@ -172,7 +172,7 @@ class AuthenticatorValidateStageTests(FlowTestCase):
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
data={
"selected_challenge": {
"device_class": DeviceClasses.WEBAUTHN,
"device_class": "baz",
"device_uid": "quox",
"challenge": {},
"last_used": None,

View File

@@ -162,7 +162,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
session = self.client.session
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
session[SESSION_KEY_PLAN] = plan
session.save()
@@ -282,7 +282,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
session = self.client.session
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
@@ -359,7 +359,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
session = self.client.session
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
@@ -441,7 +441,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
session = self.client.session
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
"device_class": device.__class__.__name__.lower().replace("device", ""),

File diff suppressed because one or more lines are too long

View File

@@ -21,6 +21,7 @@ from authentik.flows.models import FlowDesignation, FlowToken
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import QS_KEY_TOKEN, QS_QUERY
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.time import timedelta_from_string
from authentik.stages.email.flow import pickle_flow_token_for_email
from authentik.stages.email.models import EmailStage
@@ -128,8 +129,9 @@ class EmailStageView(ChallengeStageView):
Event.new(
EventAction.CONFIGURATION_ERROR,
message=_("Exception occurred while rendering E-mail template"),
error=exception_to_string(exc),
template=current_stage.template,
).with_exception(exc).from_http(self.request)
).from_http(self.request)
raise StageInvalidException from exc
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
@@ -143,7 +145,7 @@ class EmailStageView(ChallengeStageView):
messages.success(request, _("Successfully verified Email."))
if self.executor.current_stage.activate_user_on_success:
user.is_active = True
user.save(update_fields=["is_active"])
user.save()
return self.executor.stage_ok()
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
self.logger.debug("No pending user")

View File

@@ -1,6 +1,6 @@
"""Prompt Stage Logic"""
from collections.abc import Callable
from collections.abc import Callable, Iterator
from email.policy import Policy
from types import MethodType
from typing import Any
@@ -190,11 +190,10 @@ class ListPolicyEngine(PolicyEngine):
self.__list = policies
self.use_cache = False
def bindings(self):
for idx, policy in enumerate(self.__list):
def iterate_bindings(self) -> Iterator[PolicyBinding]:
for policy in self.__list:
yield PolicyBinding(
policy=policy,
order=idx,
)

View File

@@ -214,7 +214,7 @@ class TestPromptStage(FlowTestCase):
"""Test challenge_response validation"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
expr = "False"
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
self.stage.validation_policies.set([expr_policy])
self.stage.save()
challenge_response = PromptChallengeResponse(
@@ -222,18 +222,6 @@ class TestPromptStage(FlowTestCase):
)
self.assertEqual(challenge_response.is_valid(), False)
def test_invalid_challenge_multiple(self):
"""Test challenge_response validation (multiple policies)"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
expr_policy1 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
expr_policy2 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
self.stage.validation_policies.set([expr_policy1, expr_policy2])
self.stage.save()
challenge_response = PromptChallengeResponse(
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), False)
def test_valid_challenge_request(self):
"""Test a request with valid challenge_response data"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
@@ -246,7 +234,7 @@ class TestPromptStage(FlowTestCase):
"return request.context['prompt_data']['password_prompt'] "
"== request.context['prompt_data']['password2_prompt']"
)
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
self.stage.validation_policies.set([expr_policy])
self.stage.save()
challenge_response = PromptChallengeResponse(

View File

@@ -18,7 +18,6 @@ class UserLoginStageSerializer(StageSerializer):
"remember_me_offset",
"network_binding",
"geoip_binding",
"remember_device",
]

View File

@@ -6,7 +6,6 @@ from django.contrib.auth.views import redirect_to_login
from django.http.request import HttpRequest
from structlog.stdlib import get_logger
from authentik.core.models import Session
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.sentry import SentryIgnoredException
@@ -90,7 +89,7 @@ class BoundSessionMiddleware(SessionMiddleware):
def recheck_session(self, request: HttpRequest):
"""Check if a session is still valid with a changed IP"""
last_ip = request.session.get(Session.Keys.LAST_IP)
last_ip = request.session.get(request.session.model.Keys.LAST_IP)
new_ip = ClientIPMiddleware.get_client_ip(request)
# Check changed IP
if new_ip == last_ip:
@@ -110,7 +109,7 @@ class BoundSessionMiddleware(SessionMiddleware):
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
# Only set the last IP in the session if there's a binding specified
# (== basically requires the user to be logged in)
request.session[Session.Keys.LAST_IP] = new_ip
request.session[request.session.model.Keys.LAST_IP] = new_ip
@staticmethod
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):

View File

@@ -1,23 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-18 16:21
import authentik.lib.utils.time
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_stages_user_login", "0006_userloginstage_geoip_binding_and_more"),
]
operations = [
migrations.AddField(
model_name="userloginstage",
name="remember_device",
field=models.TextField(
default="days=30",
help_text="When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)",
validators=[authentik.lib.utils.time.timedelta_string_validator],
),
),
]

View File

@@ -63,15 +63,6 @@ class UserLoginStage(Stage):
"(Format: hours=-1;minutes=-2;seconds=-3)"
),
)
remember_device = models.TextField(
default="days=30",
validators=[timedelta_string_validator],
help_text=_(
"When set to a non-zero value, authentik will save a cookie with a longer expiry,"
"to remember the device the user is logging in from. "
"(Format: hours=-1;minutes=-2;seconds=-3)"
),
)
@property
def serializer(self) -> type[BaseSerializer]:

View File

@@ -1,17 +1,14 @@
"""Login stage logic"""
from datetime import datetime, timedelta
from hashlib import sha256
from datetime import timedelta
from django.conf import settings
from django.contrib import messages
from django.contrib.auth import login
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
from jwt import PyJWTError, decode, encode
from rest_framework.fields import BooleanField, CharField
from authentik.core.models import AuthenticatedSession, Session, User
from authentik.core.models import Session, User
from authentik.events.middleware import audit_ignore
from authentik.flows.challenge import ChallengeResponse, WithUserInfoChallenge
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
@@ -19,20 +16,12 @@ from authentik.flows.stage import ChallengeStageView
from authentik.lib.utils.time import timedelta_from_string
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.stage import (
PLAN_CONTEXT_AUTHENTICATION_BACKEND,
PLAN_CONTEXT_METHOD_ARGS,
)
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.user_login.middleware import (
SESSION_KEY_BINDING_GEO,
SESSION_KEY_BINDING_NET,
)
from authentik.stages.user_login.models import UserLoginStage
from authentik.tenants.utils import get_unique_identifier
COOKIE_NAME_KNOWN_DEVICE = "authentik_device"
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE = "known_device"
class UserLoginChallenge(WithUserInfoChallenge):
@@ -89,63 +78,12 @@ class UserLoginStageView(ChallengeStageView):
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding
# FIXME: identical function in authenticator_validate
@property
def cookie_jwt_key(self) -> str:
"""Signing key for Known-device Cookie for this stage"""
return sha256(
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii")
).hexdigest()
def set_known_device_cookie(self, user: User):
"""Set a cookie, valid longer than the session, which denotes that this user
has logged in on this device before."""
delta = timedelta_from_string(self.executor.current_stage.remember_device)
response = self.executor.stage_ok()
if delta.total_seconds() < 1:
return response
expiry = datetime.now() + delta
cookie_payload = {
"sub": user.uid,
"exp": expiry.timestamp(),
}
cookie = encode(cookie_payload, self.cookie_jwt_key)
response.set_cookie(
COOKIE_NAME_KNOWN_DEVICE,
cookie,
expires=expiry,
path=settings.SESSION_COOKIE_PATH,
domain=settings.SESSION_COOKIE_DOMAIN,
samesite=settings.SESSION_COOKIE_SAMESITE,
)
return response
def is_known_device(self, user: User):
"""Returns `True` if the login happened on a "known" device, by the same user."""
client_ip = ClientIPMiddleware.get_client_ip(self.request)
if AuthenticatedSession.objects.filter(session__last_ip=client_ip).exists():
return True
if COOKIE_NAME_KNOWN_DEVICE not in self.request.COOKIES:
return False
try:
payload = decode(
self.request.COOKIES[COOKIE_NAME_KNOWN_DEVICE], self.cookie_jwt_key, ["HS256"]
)
if payload["sub"] == user.uid:
return True
return False
except (PyJWTError, ValueError, TypeError) as exc:
self.logger.info("eh", exc=exc)
return False
def do_login(self, request: HttpRequest, remember: bool | None = None) -> HttpResponse:
"""Attach the currently pending user to the current session.
`remember` Argument should be `None` if not configured, otherwise set to `True`/`False`
representative of the user's choice."""
def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse:
"""Attach the currently pending user to the current session"""
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
message = _("No Pending user to login.")
messages.error(request, message)
self.logger.warning(message)
self.logger.debug(message)
return self.executor.stage_invalid()
backend = self.executor.plan.context.get(
PLAN_CONTEXT_AUTHENTICATION_BACKEND, BACKEND_INBUILT
@@ -153,13 +91,8 @@ class UserLoginStageView(ChallengeStageView):
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
if not user.is_active:
self.logger.warning("User is not active, login will not work.")
delta = self.set_session_duration(bool(remember))
delta = self.set_session_duration(remember)
self.set_session_ip()
# Check if the login request is coming from a known device
self.executor.plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
self.executor.plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, self.is_known_device(user)
)
# the `user_logged_in` signal will update the user to write the `last_login` field
# which we don't want to log as we already have a dedicated login event
with audit_ignore():
@@ -179,6 +112,4 @@ class UserLoginStageView(ChallengeStageView):
Session.objects.filter(
authenticatedsession__user=user,
).exclude(session_key=self.request.session.session_key).delete()
if remember is None:
return self.set_known_device_cookie(user)
return self.executor.stage_ok()

View File

@@ -8,18 +8,17 @@ from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import AuthenticatedSession, Session
from authentik.core.tests.utils import create_test_flow, create_test_user
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.user_login.middleware import (
SESSION_KEY_BINDING_NET,
BoundSessionMiddleware,
SessionBindingBroken,
logout_extra,
@@ -32,7 +31,7 @@ class TestUserLoginStage(FlowTestCase):
def setUp(self):
super().setUp()
self.user = create_test_user()
self.user = create_test_admin_user()
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
self.stage = UserLoginStage.objects.create(name="login")
@@ -248,21 +247,3 @@ class TestUserLoginStage(FlowTestCase):
request.session = self.client.session
request.user = self.user
logout_extra(request, cm.exception)
def test_session_binding_broken(self):
"""Test session binding"""
self.client.force_login(self.user)
session = self.client.session
session[Session.Keys.LAST_IP] = "192.0.2.1"
session[SESSION_KEY_BINDING_NET] = NetworkBinding.BIND_ASN_NETWORK_IP
session.save()
res = self.client.get(reverse("authentik_api:user-me"))
self.assertEqual(res.status_code, 302)
self.assertEqual(
res.url,
reverse(
"authentik_flows:default-authentication",
)
+ f"?{NEXT_ARG_NAME}={reverse("authentik_api:user-me")}",
)

View File

@@ -4754,7 +4754,6 @@
"add_token",
"change_token",
"delete_token",
"set_token_key",
"view_token",
"view_token_key"
]
@@ -4891,7 +4890,6 @@
"authentik_core.preview_user",
"authentik_core.remove_user_from_group",
"authentik_core.reset_user_password",
"authentik_core.set_token_key",
"authentik_core.unassign_user_permissions",
"authentik_core.view_application",
"authentik_core.view_applicationentitlement",
@@ -9538,7 +9536,6 @@
"authentik_core.preview_user",
"authentik_core.remove_user_from_group",
"authentik_core.reset_user_password",
"authentik_core.set_token_key",
"authentik_core.unassign_user_permissions",
"authentik_core.view_application",
"authentik_core.view_applicationentitlement",
@@ -10961,7 +10958,6 @@
"enum": [
"apple",
"openidconnect",
"entraid",
"azuread",
"discord",
"facebook",
@@ -15550,12 +15546,6 @@
],
"title": "Geoip binding",
"description": "Bind sessions created by this stage to the configured GeoIP location"
},
"remember_device": {
"type": "string",
"minLength": 1,
"title": "Remember device",
"description": "When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)"
}
},
"required": []

17
go.mod
View File

@@ -6,18 +6,18 @@ require (
beryju.io/ldap v0.1.0
github.com/avast/retry-go/v4 v4.6.1
github.com/coreos/go-oidc/v3 v3.14.1
github.com/getsentry/sentry-go v0.34.1
github.com/getsentry/sentry-go v0.34.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.11
github.com/go-openapi/runtime v0.28.0
github.com/golang-jwt/jwt/v5 v5.2.3
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/gorilla/handlers v1.5.2
github.com/gorilla/mux v1.8.1
github.com/gorilla/securecookie v1.1.2
github.com/gorilla/sessions v1.4.0
github.com/gorilla/websocket v1.5.3
github.com/grafana/pyroscope-go v1.2.4
github.com/grafana/pyroscope-go v1.2.2
github.com/jellydator/ttlcache/v3 v3.4.0
github.com/mitchellh/mapstructure v1.5.0
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
@@ -29,10 +29,10 @@ require (
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2025063.6
goauthentik.io/api/v3 v3.2025063.1
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.16.0
golang.org/x/sync v0.15.0
gopkg.in/yaml.v2 v2.4.0
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
)
@@ -77,10 +77,9 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/net v0.40.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.24.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
go.sum
View File

@@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/getsentry/sentry-go v0.34.1 h1:HSjc1C/OsnZttohEPrrqKH42Iud0HuLCXpv8cU1pWcw=
github.com/getsentry/sentry-go v0.34.1/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
@@ -115,8 +115,8 @@ github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+Gr
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58=
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0=
github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@@ -180,8 +180,8 @@ github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2e
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafana/pyroscope-go v1.2.4 h1:B22GMXz+O0nWLatxLuaP7o7L9dvP0clLvIpmeEQQM0Q=
github.com/grafana/pyroscope-go v1.2.4/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
@@ -298,16 +298,16 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
goauthentik.io/api/v3 v3.2025063.6 h1:TFMnE0bXiWZ5oVYrnxDLpS+pnGNv+KIjLmZHT5qzpcM=
goauthentik.io/api/v3 v3.2025063.6/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
goauthentik.io/api/v3 v3.2025063.1 h1:zvKhZTESgMY/SNiLuTs7G0YleBnev1v7+S9Xd6PZ9bc=
goauthentik.io/api/v3 v3.2025063.1/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -366,8 +366,8 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -384,8 +384,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -413,15 +413,15 @@ golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@@ -28,7 +28,7 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
var deviceChallenge *api.DeviceChallenge
inner := api.NewAuthenticatorValidationChallengeResponseRequest()
for _, devCh := range challenge.AuthenticatorValidationChallenge.DeviceChallenges {
if devCh.DeviceClass == api.DEVICECLASSESENUM_DUO {
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_DUO) {
deviceChallenge = &devCh
devId, err := strconv.ParseInt(deviceChallenge.DeviceUid, 10, 32)
if err != nil {
@@ -38,8 +38,8 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
inner.SelectedChallenge = (*api.DeviceChallengeRequest)(deviceChallenge)
inner.Duo = &devId32
}
if devCh.DeviceClass == api.DEVICECLASSESENUM_STATIC ||
devCh.DeviceClass == api.DEVICECLASSESENUM_TOTP {
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_STATIC) ||
devCh.DeviceClass == string(api.DEVICECLASSESENUM_TOTP) {
// Only use code-based devices if we have a code in the entered password,
// and we haven't selected a push device yet
if deviceChallenge == nil && fe.getAnswer(StageAuthenticatorValidate) != "" {

View File

@@ -2,7 +2,6 @@ package radius
import (
"encoding/base64"
"errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
@@ -12,7 +11,9 @@ import (
"layeh.com/radius/rfc2865"
)
func (rs *RadiusServer) Handle_AccessRequest_PAP_Auth(r *RadiusRequest, username, password string) (*radius.Packet, error) {
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
username := rfc2865.UserName_GetString(r.Packet)
fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
"username": username,
"client": r.RemoteAddr(),
@@ -22,64 +23,67 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP_Auth(r *RadiusRequest, username
fe.Params.Add("goauthentik.io/outpost/radius", "true")
fe.Answers[flow.StageIdentification] = username
fe.SetSecrets(password, r.pi.MFASupport)
fe.SetSecrets(rfc2865.UserPassword_GetString(r.Packet), r.pi.MFASupport)
passed, err := fe.Execute()
if err != nil {
r.Log().WithField("username", username).WithError(err).Warning("failed to execute flow")
return nil, errors.New("flow_error")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"reason": "flow_error",
"app": r.pi.appSlug,
}).Inc()
_ = w.Write(r.Response(radius.CodeAccessReject))
return
}
if !passed {
return nil, errors.New("invalid_credentials")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"reason": "invalid_credentials",
"app": r.pi.appSlug,
}).Inc()
_ = w.Write(r.Response(radius.CodeAccessReject))
return
}
access, _, err := fe.ApiClient().OutpostsApi.OutpostsRadiusAccessCheck(
r.Context(), r.pi.providerId,
).AppSlug(r.pi.appSlug).Execute()
if err != nil {
r.Log().WithField("username", username).WithError(err).Warning("failed to check access")
return nil, errors.New("access_check_fail")
_ = w.Write(r.Response(radius.CodeAccessReject))
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"reason": "access_check_fail",
"app": r.pi.appSlug,
}).Inc()
return
}
if !access.Access.Passing {
r.Log().WithField("username", username).Info("Access denied for user")
return nil, errors.New("access_denied")
_ = w.Write(r.Response(radius.CodeAccessReject))
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"reason": "access_denied",
"app": r.pi.appSlug,
}).Inc()
return
}
res := r.Response(radius.CodeAccessAccept)
defer func() { _ = w.Write(res) }()
if !access.HasAttributes() {
r.Log().Debug("No attributes")
return res, nil
return
}
rawData, err := base64.StdEncoding.DecodeString(access.GetAttributes())
if err != nil {
r.Log().WithError(err).Warning("failed to decode attributes from core")
return nil, errors.New("attribute_decode_failed")
return
}
p, err := radius.Parse(rawData, r.pi.SharedSecret)
if err != nil {
r.Log().WithError(err).Warning("failed to parse attributes from core")
return nil, errors.New("attribute_parse_failed")
}
for _, attr := range p.Attributes {
res.Add(attr.Type, attr.Attribute)
}
return res, nil
}
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
username := rfc2865.UserName_GetString(r.Packet)
password := rfc2865.UserPassword_GetString(r.Packet)
res, err := rs.Handle_AccessRequest_PAP_Auth(r, username, password)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"reason": err.Error(),
"app": r.pi.appSlug,
}).Inc()
_ = w.Write(r.Reject())
return
}
err = r.setMessageAuthenticator(res)
if err != nil {
rs.log.WithError(err).Warning("failed to set message authenticator")
}
_ = w.Write(res)
}

View File

@@ -1,12 +1,8 @@
package radius
import (
"bytes"
"crypto/hmac"
"crypto/md5"
"crypto/sha512"
"encoding/hex"
"errors"
"time"
"github.com/getsentry/sentry-go"
@@ -17,11 +13,6 @@ import (
"goauthentik.io/internal/outpost/radius/metrics"
"goauthentik.io/internal/utils"
"layeh.com/radius"
"layeh.com/radius/rfc2869"
)
var (
ErrInvalidMessageAuthenticator = errors.New("invalid message authenticator")
)
type RadiusRequest struct {
@@ -44,41 +35,6 @@ func (r *RadiusRequest) ID() string {
return r.id
}
func (r *RadiusRequest) validateMessageAuthenticator() error {
mauth := rfc2869.MessageAuthenticator_Get(r.Packet)
hash := hmac.New(md5.New, r.Secret)
encode, err := r.MarshalBinary()
if err != nil {
return err
}
hash.Write(encode)
if bytes.Equal(mauth, hash.Sum(nil)) {
return ErrInvalidMessageAuthenticator
}
return nil
}
func (r *RadiusRequest) setMessageAuthenticator(rp *radius.Packet) error {
_ = rfc2869.MessageAuthenticator_Set(rp, make([]byte, 16))
hash := hmac.New(md5.New, rp.Secret)
encode, err := rp.MarshalBinary()
if err != nil {
return err
}
hash.Write(encode)
_ = rfc2869.MessageAuthenticator_Set(rp, hash.Sum(nil))
return nil
}
func (r *RadiusRequest) Reject() *radius.Packet {
res := r.Response(radius.CodeAccessReject)
err := r.setMessageAuthenticator(res)
if err != nil {
r.log.WithError(err).Warning("failed to set message authenticator")
}
return res
}
func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) {
span := sentry.StartSpan(r.Context(), "authentik.providers.radius.connect",
sentry.WithTransactionName("authentik.providers.radius.connect"))
@@ -103,11 +59,6 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
rl.Info("Radius Request")
if err := nr.validateMessageAuthenticator(); err != nil {
rl.WithError(err).Warning("Invalid message authenticator")
return
}
// Lookup provider by shared secret
var pi *ProviderInstance
for _, p := range rs.providers {
@@ -121,7 +72,7 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
hs := sha512.Sum512([]byte(r.Secret))
bs := hex.EncodeToString(hs[:])
nr.Log().WithField("hashed_secret", bs).Warning("No provider found")
_ = w.Write(nr.Reject())
_ = w.Write(r.Response(radius.CodeAccessReject))
return
}
nr.pi = pi

View File

@@ -9,7 +9,7 @@
"version": "0.0.0",
"license": "MIT",
"devDependencies": {
"aws-cdk": "^2.1021.0",
"aws-cdk": "^2.1020.1",
"cross-env": "^7.0.3"
},
"engines": {
@@ -17,11 +17,10 @@
}
},
"node_modules/aws-cdk": {
"version": "2.1021.0",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1021.0.tgz",
"integrity": "sha512-kE557b4N9UFWax+7km3R6D56o4tGhpzOks/lRDugaoC8su3mocLCXJhb954b/IRl0ipnbZnY/Sftq+RQ/sxivg==",
"version": "2.1020.1",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1020.1.tgz",
"integrity": "sha512-4UG9qzf6ZSDjINubcukPZChVj6PvDJAHiURAw0jYSkUhObPkX7Zo9uNUIlXzrM+hpB2N2jwRKY9b3sN+KDbtAQ==",
"dev": true,
"license": "Apache-2.0",
"bin": {
"cdk": "bin/cdk"
},

View File

@@ -10,7 +10,7 @@
"node": ">=20"
},
"devDependencies": {
"aws-cdk": "^2.1021.0",
"aws-cdk": "^2.1020.1",
"cross-env": "^7.0.3"
}
}

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More