mirror of
https://github.com/goauthentik/authentik
synced 2026-05-06 15:12:13 +02:00
Compare commits
2 Commits
remote_deb
...
pr-21647
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5050a94f16 | ||
|
|
f5dd1b62ef |
2
.github/actions/setup/action.yml
vendored
2
.github/actions/setup/action.yml
vendored
@@ -64,7 +64,7 @@ runs:
|
||||
rustflags: ""
|
||||
- name: Setup rust dependencies
|
||||
if: ${{ contains(inputs.dependencies, 'rust') }}
|
||||
uses: taiki-e/install-action@cf525cb33f51aca27cd6fa02034117ab963ff9f1 # v2
|
||||
uses: taiki-e/install-action@5f57d6cb7cd20b14a8a27f522884c4bc8a187458 # v2
|
||||
with:
|
||||
tool: cargo-deny cargo-machete cargo-llvm-cov nextest
|
||||
- name: Setup node (web)
|
||||
|
||||
2
.github/workflows/_reusable-docker-build.yml
vendored
2
.github/workflows/_reusable-docker-build.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: int128/docker-manifest-create-action@7df7f9e221d927eaadf87db231ddf728047308a4 # v2
|
||||
- uses: int128/docker-manifest-create-action@3de37de96c4e900bc3eef9055d3c50abdb4f769d # v2
|
||||
id: build
|
||||
with:
|
||||
tags: ${{ matrix.tag }}
|
||||
|
||||
@@ -14,7 +14,6 @@ pyproject.toml @goauthentik/backend
|
||||
uv.lock @goauthentik/backend
|
||||
Cargo.toml @goauthentik/backend
|
||||
Cargo.lock @goauthentik/backend
|
||||
build.rs @goauthentik/backend
|
||||
go.mod @goauthentik/backend
|
||||
go.sum @goauthentik/backend
|
||||
.cargo/ @goauthentik/backend
|
||||
|
||||
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -198,7 +198,6 @@ dependencies = [
|
||||
"metrics-exporter-prometheus",
|
||||
"nix 0.31.2",
|
||||
"pyo3",
|
||||
"pyo3-build-config",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -217,7 +216,6 @@ dependencies = [
|
||||
"eyre",
|
||||
"forwarded-header-value",
|
||||
"futures",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower",
|
||||
@@ -1507,9 +1505,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-unix-socket"
|
||||
version = "0.6.1"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88978f1d73da0eb87d86555fcc40cbdd87bc86eb6525710b89db8c9278ec6a59"
|
||||
checksum = "c255628da188a9d9ee218bae99da33a4b684ed63abe140a94d0f6e4b5af9a090"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"hyper",
|
||||
@@ -3000,9 +2998,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.40"
|
||||
version = "0.23.39"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
|
||||
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"log",
|
||||
|
||||
@@ -39,7 +39,7 @@ eyre = "= 0.6.12"
|
||||
forwarded-header-value = "= 0.1.1"
|
||||
futures = "= 0.3.32"
|
||||
glob = "= 0.3.3"
|
||||
hyper-unix-socket = "= 0.6.1"
|
||||
hyper-unix-socket = "= 0.3.0"
|
||||
hyper-util = "= 0.1.20"
|
||||
ipnet = { version = "= 2.12.0", features = ["serde"] }
|
||||
json-subscriber = "= 0.2.8"
|
||||
@@ -49,7 +49,6 @@ nix = { version = "= 0.31.2", features = ["hostname", "signal"] }
|
||||
notify = "= 8.2.0"
|
||||
pin-project-lite = "= 0.2.17"
|
||||
pyo3 = "= 0.28.3"
|
||||
pyo3-build-config = "= 0.28.3"
|
||||
regex = "= 1.12.3"
|
||||
reqwest = { version = "= 0.13.2", features = [
|
||||
"form",
|
||||
@@ -66,7 +65,7 @@ reqwest-middleware = { version = "= 0.5.1", features = [
|
||||
"query",
|
||||
"rustls",
|
||||
] }
|
||||
rustls = { version = "= 0.23.40", features = ["fips"] }
|
||||
rustls = { version = "= 0.23.39", features = ["fips"] }
|
||||
sentry = { version = "= 0.47.0", default-features = false, features = [
|
||||
"backtrace",
|
||||
"contexts",
|
||||
@@ -261,9 +260,6 @@ default = ["core", "proxy"]
|
||||
core = ["ak-common/core", "dep:pyo3", "dep:sqlx"]
|
||||
proxy = ["ak-common/proxy"]
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config.workspace = true
|
||||
|
||||
[dependencies]
|
||||
ak-axum.workspace = true
|
||||
ak-common.workspace = true
|
||||
|
||||
3
Makefile
3
Makefile
@@ -118,9 +118,6 @@ run-worker: ## Run the main authentik worker process
|
||||
run-worker-watch: ## Run the authentik worker, with auto reloading
|
||||
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- $(UV) run ak worker
|
||||
|
||||
debug-attach: ## Attach pdb to a running authentik Python worker (PEP 768). PID=<pid> to pick; SUDO=1 on macOS.
|
||||
$(UV) run python scripts/debug_attach.py
|
||||
|
||||
core-i18n-extract:
|
||||
$(UV) run ak makemessages \
|
||||
--add-location file \
|
||||
|
||||
@@ -5,7 +5,6 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code"
|
||||
GRANT_TYPE_IMPLICIT = "implicit"
|
||||
GRANT_TYPE_HYBRID = "hybrid"
|
||||
GRANT_TYPE_REFRESH_TOKEN = "refresh_token" # nosec
|
||||
GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
|
||||
GRANT_TYPE_PASSWORD = "password" # nosec
|
||||
|
||||
@@ -30,8 +30,6 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
||||
|
||||
SAML_STATUS_SUCCESS = "urn:oasis:names:tc:SAML:2.0:status:Success"
|
||||
|
||||
DEFAULT_ISSUER = "authentik"
|
||||
|
||||
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
|
||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Iterator
|
||||
from copy import copy
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Case, QuerySet
|
||||
from django.db.models import Case, Q, QuerySet
|
||||
from django.db.models.expressions import When
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -120,7 +120,6 @@ class ApplicationSerializer(ModelSerializer):
|
||||
"meta_publisher",
|
||||
"policy_engine_mode",
|
||||
"group",
|
||||
"meta_hide",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"backchannel_providers": {"required": False},
|
||||
@@ -284,12 +283,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
) == "true"
|
||||
|
||||
queryset = self._filter_queryset_for_list(self.get_queryset())
|
||||
queryset = queryset.exclude(meta_hide=True)
|
||||
if only_with_launch_url:
|
||||
# Pre-filter at DB level to skip expensive per-app policy evaluation
|
||||
# for apps that can never appear in the launcher (no meta_launch_url
|
||||
# and no provider, so no possible launch URL).
|
||||
queryset = queryset.exclude(meta_launch_url="", provider__isnull=True)
|
||||
# for apps that can never appear in the launcher:
|
||||
# - No meta_launch_url AND no provider: no possible launch URL
|
||||
# - meta_launch_url="blank://blank": documented convention to hide from launcher
|
||||
queryset = queryset.exclude(
|
||||
Q(meta_launch_url="", provider__isnull=True) | Q(meta_launch_url="blank://blank")
|
||||
)
|
||||
paginator: Pagination = self.paginator
|
||||
paginated_apps = paginator.paginate_queryset(queryset, request)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField, IntegerField, SerializerMethodField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ListSerializer, ValidationError
|
||||
@@ -37,77 +37,6 @@ from authentik.endpoints.connectors.agent.auth import AgentAuth
|
||||
from authentik.rbac.api.roles import RoleSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
|
||||
class BulkManyRelatedField(ManyRelatedField):
|
||||
"""ManyRelatedField that validates all PKs in a single query instead of one per PK."""
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if isinstance(data, str) or not hasattr(data, "__iter__"):
|
||||
self.fail("not_a_list", input_type=type(data).__name__)
|
||||
if not self.allow_empty and len(data) == 0:
|
||||
self.fail("empty")
|
||||
|
||||
child = self.child_relation
|
||||
pk_field = child.pk_field
|
||||
# Coerce PKs through pk_field if defined
|
||||
pk_map = {}
|
||||
for item in data:
|
||||
if isinstance(item, bool):
|
||||
self.fail("incorrect_type", data_type=type(item).__name__)
|
||||
pk = pk_field.to_internal_value(item) if pk_field else item
|
||||
pk_map[pk] = item # map coerced PK -> original value for error reporting
|
||||
|
||||
queryset = child.get_queryset()
|
||||
# Use count to validate all PKs exist in a single query
|
||||
found_count = queryset.filter(pk__in=pk_map.keys()).count()
|
||||
if found_count < len(pk_map):
|
||||
# Some PKs not found — fall back to per-PK checks for error reporting.
|
||||
# This only runs when there's an actual validation error (rare path).
|
||||
for pk, original in pk_map.items():
|
||||
if not queryset.filter(pk=pk).exists():
|
||||
child.fail("does_not_exist", pk_value=original)
|
||||
|
||||
# Return raw PKs — Django's M2M set() accepts both objects and PKs,
|
||||
# using get_prep_value() for type coercion. This avoids loading all
|
||||
# objects into memory and avoids triggering post_init signals.
|
||||
return list(pk_map.keys())
|
||||
|
||||
def to_representation(self, iterable):
|
||||
# For non-prefetched querysets, get PKs directly without loading model instances.
|
||||
# When prefetched, _result_cache is a list (possibly empty); when not, it's None.
|
||||
if hasattr(iterable, "values_list") and getattr(iterable, "_result_cache", None) is None:
|
||||
return list(iterable.values_list("pk", flat=True))
|
||||
return super().to_representation(iterable)
|
||||
|
||||
|
||||
class BulkPrimaryKeyRelatedField(PrimaryKeyRelatedField):
|
||||
"""PrimaryKeyRelatedField that uses bulk validation when many=True."""
|
||||
|
||||
@classmethod
|
||||
def many_init(cls, *args, **kwargs):
|
||||
allow_empty = kwargs.pop("allow_empty", None)
|
||||
max_length = kwargs.pop("max_length", None)
|
||||
min_length = kwargs.pop("min_length", None)
|
||||
child_relation = cls(*args, **kwargs)
|
||||
list_kwargs = {
|
||||
"child_relation": child_relation,
|
||||
}
|
||||
if allow_empty is not None:
|
||||
list_kwargs["allow_empty"] = allow_empty
|
||||
if max_length is not None:
|
||||
list_kwargs["max_length"] = max_length
|
||||
if min_length is not None:
|
||||
list_kwargs["min_length"] = min_length
|
||||
list_kwargs.update(
|
||||
{
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in ("required", "default", "source")
|
||||
}
|
||||
)
|
||||
return BulkManyRelatedField(**list_kwargs)
|
||||
|
||||
|
||||
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
|
||||
"pk",
|
||||
"username",
|
||||
@@ -150,7 +79,6 @@ class GroupSerializer(ModelSerializer):
|
||||
"""Group Serializer"""
|
||||
|
||||
attributes = JSONDictField(required=False)
|
||||
users = BulkPrimaryKeyRelatedField(queryset=User.objects.all(), many=True, default=list)
|
||||
parents = PrimaryKeyRelatedField(queryset=Group.objects.all(), many=True, required=False)
|
||||
parents_obj = SerializerMethodField(allow_null=True)
|
||||
children_obj = SerializerMethodField(allow_null=True)
|
||||
@@ -265,6 +193,9 @@ class GroupSerializer(ModelSerializer):
|
||||
"children_obj",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"users": {
|
||||
"default": list,
|
||||
},
|
||||
"children": {
|
||||
"required": False,
|
||||
"default": list,
|
||||
@@ -294,7 +225,6 @@ class GroupFilter(FilterSet):
|
||||
members_by_pk = ModelMultipleChoiceFilter(
|
||||
field_name="users",
|
||||
queryset=User.objects.all(),
|
||||
distinct=False,
|
||||
)
|
||||
|
||||
def filter_attributes(self, queryset, name, value):
|
||||
@@ -346,8 +276,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
# Always prefetch parents and children since their PKs are always serialized
|
||||
base_qs = Group.objects.all().prefetch_related("roles", "parents", "children")
|
||||
base_qs = Group.objects.all().prefetch_related("roles")
|
||||
|
||||
if self.serializer_class(context={"request": self.request})._should_include_users:
|
||||
# Only fetch fields needed by PartialUserSerializer to reduce DB load and instantiation
|
||||
@@ -358,9 +287,16 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
|
||||
)
|
||||
)
|
||||
# When include_users=false, skip users prefetch entirely.
|
||||
# BulkManyRelatedField.to_representation will use values_list to get PKs
|
||||
# directly without loading User instances into memory.
|
||||
else:
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch("users", queryset=User.objects.all().only("id"))
|
||||
)
|
||||
|
||||
if self.serializer_class(context={"request": self.request})._should_include_children:
|
||||
base_qs = base_qs.prefetch_related("children")
|
||||
|
||||
if self.serializer_class(context={"request": self.request})._should_include_parents:
|
||||
base_qs = base_qs.prefetch_related("parents")
|
||||
|
||||
return base_qs
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from django.contrib.auth import update_session_auth_hash
|
||||
from django.contrib.auth.models import AnonymousUser, Permission
|
||||
from django.db.models import Exists, OuterRef, Prefetch, Q
|
||||
from django.db.transaction import atomic
|
||||
from django.db.utils import IntegrityError
|
||||
from django.urls import reverse_lazy
|
||||
@@ -14,7 +13,6 @@ from django.utils.http import urlencode
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from django.utils.translation import gettext_lazy
|
||||
from django_filters.filters import (
|
||||
BooleanFilter,
|
||||
CharFilter,
|
||||
@@ -107,10 +105,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
INVALID_PASSWORD_HASH_MESSAGE = gettext_lazy(
|
||||
"Invalid password hash format. Must be a valid Django password hash."
|
||||
)
|
||||
|
||||
|
||||
class ParamUserSerializer(PassiveSerializer):
|
||||
"""Partial serializer for query parameters to select a user"""
|
||||
@@ -137,7 +131,7 @@ class PartialGroupSerializer(ModelSerializer):
|
||||
class UserSerializer(ModelSerializer):
|
||||
"""User Serializer"""
|
||||
|
||||
is_superuser = SerializerMethodField()
|
||||
is_superuser = BooleanField(read_only=True)
|
||||
avatar = SerializerMethodField()
|
||||
attributes = JSONDictField(required=False)
|
||||
groups = PrimaryKeyRelatedField(
|
||||
@@ -174,14 +168,6 @@ class UserSerializer(ModelSerializer):
|
||||
return True
|
||||
return str(request.query_params.get("include_roles", "true")).lower() == "true"
|
||||
|
||||
@extend_schema_field(BooleanField)
|
||||
def get_is_superuser(self, instance: User) -> bool:
|
||||
"""Use annotation if available to avoid N+1 query"""
|
||||
ann = getattr(instance, "_annotated_is_superuser", None)
|
||||
if ann is not None:
|
||||
return ann
|
||||
return instance.is_superuser
|
||||
|
||||
@extend_schema_field(PartialGroupSerializer(many=True))
|
||||
def get_groups_obj(self, instance: User) -> list[PartialGroupSerializer] | None:
|
||||
if not self._should_include_groups:
|
||||
@@ -195,79 +181,47 @@ class UserSerializer(ModelSerializer):
|
||||
return RoleSerializer(instance.roles, many=True).data
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Setting password and permissions directly is allowed only in blueprints."""
|
||||
super().__init__(*args, **kwargs)
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["password"] = CharField(required=False, allow_null=True)
|
||||
self.fields["password_hash"] = CharField(required=False, allow_null=True)
|
||||
self.fields["permissions"] = ListField(
|
||||
required=False,
|
||||
child=ChoiceField(choices=get_permission_choices()),
|
||||
)
|
||||
|
||||
def create(self, validated_data: dict) -> User:
|
||||
"""Create a user, with blueprint-only password and permission writes."""
|
||||
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
|
||||
if is_blueprint:
|
||||
password = validated_data.pop("password", None)
|
||||
password_hash = validated_data.pop("password_hash", None)
|
||||
permissions = validated_data.pop("permissions", [])
|
||||
self._validate_password_inputs(password, password_hash)
|
||||
|
||||
"""If this serializer is used in the blueprint context, we allow for
|
||||
directly setting a password. However should be done via the `set_password`
|
||||
method instead of directly setting it like rest_framework."""
|
||||
password = validated_data.pop("password", None)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
|
||||
instance: User = super().create(validated_data)
|
||||
if is_blueprint:
|
||||
self._set_password(instance, password, password_hash)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[permission.split(".")[1] for permission in permissions]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
self._ensure_password_not_empty(instance)
|
||||
self._set_password(instance, password)
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
return instance
|
||||
|
||||
def update(self, instance: User, validated_data: dict) -> User:
|
||||
"""Update a user, with blueprint-only password and permission writes."""
|
||||
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
|
||||
if is_blueprint:
|
||||
password = validated_data.pop("password", None)
|
||||
password_hash = validated_data.pop("password_hash", None)
|
||||
permissions = validated_data.pop("permissions", [])
|
||||
self._validate_password_inputs(password, password_hash)
|
||||
|
||||
"""Same as `create` above, set the password directly if we're in a blueprint
|
||||
context"""
|
||||
password = validated_data.pop("password", None)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
|
||||
instance = super().update(instance, validated_data)
|
||||
if is_blueprint:
|
||||
self._set_password(instance, password, password_hash)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[permission.split(".")[1] for permission in permissions]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
self._ensure_password_not_empty(instance)
|
||||
self._set_password(instance, password)
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
return instance
|
||||
|
||||
def _validate_password_inputs(self, password: str | None, password_hash: str | None):
|
||||
"""Validate mutually-exclusive password inputs before any model mutation."""
|
||||
if password is not None and password_hash is not None:
|
||||
raise ValidationError(_("Cannot set both password and password_hash. Use only one."))
|
||||
if password_hash is None:
|
||||
return
|
||||
try:
|
||||
User.validate_password_hash(password_hash)
|
||||
except ValueError as exc:
|
||||
LOGGER.warning("Failed to identify password hash format", exc_info=exc)
|
||||
raise ValidationError(INVALID_PASSWORD_HASH_MESSAGE) from exc
|
||||
|
||||
def _set_password(self, instance: User, password: str | None, password_hash: str | None = None):
|
||||
"""Set password from plain text or hash."""
|
||||
if password_hash is not None:
|
||||
instance.set_password_from_hash(password_hash)
|
||||
instance.save()
|
||||
elif password:
|
||||
def _set_password(self, instance: User, password: str | None):
|
||||
"""Set password of user if we're in a blueprint context, and if it's an empty
|
||||
string then use an unusable password"""
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
|
||||
instance.set_password(password)
|
||||
instance.save()
|
||||
|
||||
def _ensure_password_not_empty(self, instance: User):
|
||||
"""Store an explicit unusable password instead of an empty password field."""
|
||||
if len(instance.password) == 0:
|
||||
instance.set_unusable_password()
|
||||
instance.save()
|
||||
@@ -436,12 +390,6 @@ class UserPasswordSetSerializer(PassiveSerializer):
|
||||
password = CharField(required=True)
|
||||
|
||||
|
||||
class UserPasswordHashSetSerializer(PassiveSerializer):
|
||||
"""Payload to set a users' password hash directly"""
|
||||
|
||||
password = CharField(required=True)
|
||||
|
||||
|
||||
class UserServiceAccountSerializer(PassiveSerializer):
|
||||
"""Payload to create a service account"""
|
||||
|
||||
@@ -593,30 +541,10 @@ class UserViewSet(
|
||||
|
||||
def get_queryset(self):
|
||||
base_qs = User.objects.all().exclude_anonymous()
|
||||
# Always prefetch groups since group PKs are always serialized.
|
||||
# Use full prefetch when include_groups=true (for groups_obj), ID-only otherwise.
|
||||
if self.serializer_class(context={"request": self.request})._should_include_groups:
|
||||
base_qs = base_qs.prefetch_related("groups")
|
||||
else:
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch("groups", queryset=Group.objects.all().only("group_uuid"))
|
||||
)
|
||||
if self.serializer_class(context={"request": self.request})._should_include_roles:
|
||||
base_qs = base_qs.prefetch_related("roles")
|
||||
else:
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch("roles", queryset=Role.objects.all().only("uuid"))
|
||||
)
|
||||
# Annotate is_superuser to avoid N+1 query per user
|
||||
base_qs = base_qs.annotate(
|
||||
_annotated_is_superuser=Exists(
|
||||
Group.objects.filter(
|
||||
is_superuser=True,
|
||||
).filter(
|
||||
Q(users=OuterRef("pk")) | Q(descendant_nodes__descendant__users=OuterRef("pk"))
|
||||
)
|
||||
)
|
||||
)
|
||||
return base_qs
|
||||
|
||||
@extend_schema(
|
||||
@@ -785,11 +713,6 @@ class UserViewSet(
|
||||
self.request.session.modified = True
|
||||
return Response(serializer.initial_data)
|
||||
|
||||
def _update_session_hash_after_password_change(self, request: Request, user: User):
|
||||
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
|
||||
LOGGER.debug("Updating session hash after password change")
|
||||
update_session_auth_hash(self.request, user)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
@extend_schema(
|
||||
request=UserPasswordSetSerializer,
|
||||
@@ -813,45 +736,9 @@ class UserViewSet(
|
||||
except (ValidationError, IntegrityError) as exc:
|
||||
LOGGER.debug("Failed to set password", exc=exc)
|
||||
return Response(status=400)
|
||||
self._update_session_hash_after_password_change(request, user)
|
||||
return Response(status=204)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
@extend_schema(
|
||||
request=UserPasswordHashSetSerializer,
|
||||
responses={
|
||||
204: OpenApiResponse(description="Successfully changed password"),
|
||||
400: OpenApiResponse(description="Bad request"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["POST"],
|
||||
permission_classes=[IsAuthenticated],
|
||||
)
|
||||
@validate(UserPasswordHashSetSerializer)
|
||||
def set_password_hash(
|
||||
self, request: Request, pk: int, body: UserPasswordHashSetSerializer
|
||||
) -> Response:
|
||||
"""Set a user's password from a pre-hashed Django password value.
|
||||
|
||||
Submit the Django password hash in the shared ``password`` request field.
|
||||
|
||||
This updates authentik's local password verifier only. It does not attempt
|
||||
to propagate the password change to LDAP or Kerberos because no raw password
|
||||
is available from the request payload.
|
||||
"""
|
||||
user: User = self.get_object()
|
||||
try:
|
||||
user.set_password_from_hash(body.validated_data["password"], request=request)
|
||||
user.save()
|
||||
except ValueError as exc:
|
||||
LOGGER.debug("Failed to set password hash", exc=exc)
|
||||
return Response(data={"password": [INVALID_PASSWORD_HASH_MESSAGE]}, status=400)
|
||||
except (ValidationError, IntegrityError) as exc:
|
||||
LOGGER.debug("Failed to set password hash", exc=exc)
|
||||
return Response(status=400)
|
||||
self._update_session_hash_after_password_change(request, user)
|
||||
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
|
||||
LOGGER.debug("Updating session hash after password change")
|
||||
update_session_auth_hash(self.request, user)
|
||||
return Response(status=204)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Hash password using Django's password hashers"""
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Hash a password using Django's password hashers"""
|
||||
|
||||
help = "Hash a password for use with AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"password",
|
||||
type=str,
|
||||
help="Password to hash",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
password = options["password"]
|
||||
|
||||
if not password:
|
||||
raise CommandError("Password cannot be empty")
|
||||
try:
|
||||
hashed = make_password(password)
|
||||
self.stdout.write(hashed)
|
||||
except ValueError as exc:
|
||||
raise CommandError(f"Error hashing password: {exc}") from exc
|
||||
@@ -1,33 +0,0 @@
|
||||
# Generated by Django 5.2.12 on 2026-04-09 18:04
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import migrations, models
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def migrate_blank_launch_url(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
db_alias = schema_editor.connection.alias
|
||||
Application = apps.get_model("authentik_core", "Application")
|
||||
|
||||
Application.objects.using(db_alias).filter(meta_launch_url="blank://blank").update(
|
||||
meta_hide=True, meta_launch_url=""
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0058_setup"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="application",
|
||||
name="meta_hide",
|
||||
field=models.BooleanField(
|
||||
default=False,
|
||||
help_text="Hide this application from the user's My applications page.",
|
||||
),
|
||||
),
|
||||
migrations.RunPython(migrate_blank_launch_url, migrations.RunPython.noop),
|
||||
]
|
||||
@@ -10,7 +10,7 @@ from uuid import uuid4
|
||||
|
||||
import pgtrigger
|
||||
from deepmerge import always_merger
|
||||
from django.contrib.auth.hashers import check_password, identify_hasher
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.contrib.auth.models import AbstractUser, Permission
|
||||
from django.contrib.auth.models import UserManager as DjangoUserManager
|
||||
from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
@@ -560,33 +560,6 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
|
||||
self.password_change_date = now()
|
||||
return super().set_password(raw_password)
|
||||
|
||||
@staticmethod
|
||||
def validate_password_hash(password_hash: str):
|
||||
"""Validate that the value is a recognized Django password hash."""
|
||||
identify_hasher(password_hash) # Raises ValueError if invalid
|
||||
|
||||
def set_password_from_hash(self, password_hash: str, signal=True, sender=None, request=None):
|
||||
"""Set password directly from a pre-hashed value.
|
||||
|
||||
Unlike set_password(), this does not hash the input again. The provided value
|
||||
must already be a valid Django password hash, and it is stored directly on the
|
||||
user after validation.
|
||||
|
||||
Because no raw password is available, downstream password sync integrations
|
||||
such as LDAP and Kerberos cannot be updated from this code path.
|
||||
|
||||
Raises ValueError if the hash format is not recognized.
|
||||
"""
|
||||
self.validate_password_hash(password_hash)
|
||||
if self.pk and signal:
|
||||
from authentik.core.signals import password_hash_changed
|
||||
|
||||
if not sender:
|
||||
sender = self
|
||||
password_hash_changed.send(sender=sender, user=self, request=request)
|
||||
self.password = password_hash
|
||||
self.password_change_date = now()
|
||||
|
||||
def check_password(self, raw_password: str) -> bool:
|
||||
"""
|
||||
Return a boolean of whether the raw_password was correct. Handles
|
||||
@@ -762,9 +735,6 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
meta_icon = FileField(default="", blank=True)
|
||||
meta_description = models.TextField(default="", blank=True)
|
||||
meta_publisher = models.TextField(default="", blank=True)
|
||||
meta_hide = models.BooleanField(
|
||||
default=False, help_text=_("Hide this application from the user's My applications page.")
|
||||
)
|
||||
|
||||
objects = ApplicationQuerySet.as_manager()
|
||||
|
||||
|
||||
@@ -16,11 +16,7 @@ LOGGER = get_logger()
|
||||
|
||||
@receiver(post_startup)
|
||||
def post_startup_setup_bootstrap(sender, **_):
|
||||
if (
|
||||
not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD")
|
||||
and not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH")
|
||||
and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN")
|
||||
):
|
||||
if not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD") and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN"):
|
||||
return
|
||||
LOGGER.info("Configuring authentik through bootstrap environment variables")
|
||||
content = BlueprintInstance(path=BOOTSTRAP_BLUEPRINT).retrieve()
|
||||
|
||||
@@ -24,8 +24,6 @@ from authentik.root.ws.consumer import build_device_group
|
||||
|
||||
# Arguments: user: User, password: str
|
||||
password_changed = Signal()
|
||||
# Arguments: user: User, request: HttpRequest | None
|
||||
password_hash_changed = Signal()
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest,
|
||||
# stage: Stage, context: dict[str, any]
|
||||
login_failed = Signal()
|
||||
|
||||
@@ -129,7 +129,6 @@ class TestApplicationsAPI(APITestCase):
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
"meta_description": "",
|
||||
"meta_hide": False,
|
||||
"meta_publisher": "",
|
||||
"policy_engine_mode": "any",
|
||||
},
|
||||
@@ -188,14 +187,12 @@ class TestApplicationsAPI(APITestCase):
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
"meta_description": "",
|
||||
"meta_hide": False,
|
||||
"meta_publisher": "",
|
||||
"policy_engine_mode": "any",
|
||||
},
|
||||
{
|
||||
"launch_url": None,
|
||||
"meta_description": "",
|
||||
"meta_hide": False,
|
||||
"meta_icon": "",
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Tests for hash_password management command."""
|
||||
|
||||
from io import StringIO
|
||||
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from django.test import TestCase
|
||||
|
||||
|
||||
class TestHashPasswordCommand(TestCase):
|
||||
"""Test hash_password management command."""
|
||||
|
||||
def test_hash_password(self):
|
||||
"""Test hashing a password."""
|
||||
out = StringIO()
|
||||
call_command("hash_password", "test123", stdout=out)
|
||||
hashed = out.getvalue().strip()
|
||||
|
||||
self.assertTrue(hashed.startswith("pbkdf2_sha256$"))
|
||||
self.assertTrue(check_password("test123", hashed))
|
||||
|
||||
def test_hash_password_empty_fails(self):
|
||||
"""Test that empty password raises error."""
|
||||
with self.assertRaises(CommandError) as ctx:
|
||||
call_command("hash_password", "")
|
||||
|
||||
self.assertIn("Password cannot be empty", str(ctx.exception))
|
||||
@@ -1,7 +1,6 @@
|
||||
from http import HTTPStatus
|
||||
from os import environ
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@@ -17,7 +16,6 @@ from authentik.tenants.flags import patch_flag
|
||||
class TestSetup(FlowTestCase):
|
||||
def tearDown(self):
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD", None)
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH", None)
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_TOKEN", None)
|
||||
|
||||
@patch_flag(Setup, True)
|
||||
@@ -156,19 +154,3 @@ class TestSetup(FlowTestCase):
|
||||
token = Token.objects.filter(identifier="authentik-bootstrap-token").first()
|
||||
self.assertEqual(token.intent, TokenIntents.INTENT_API)
|
||||
self.assertEqual(token.key, environ["AUTHENTIK_BOOTSTRAP_TOKEN"])
|
||||
|
||||
def test_setup_bootstrap_env_password_hash(self):
|
||||
"""Test setup with password hash env var"""
|
||||
User.objects.filter(username="akadmin").delete()
|
||||
Setup.set(False)
|
||||
|
||||
password = generate_id()
|
||||
password_hash = make_password(password)
|
||||
environ["AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"] = password_hash
|
||||
pre_startup.send(sender=self)
|
||||
post_startup.send(sender=self)
|
||||
|
||||
self.assertTrue(Setup.get())
|
||||
user = User.objects.get(username="akadmin")
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
"""user tests"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.test.testcases import TestCase
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.users import UserSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.core.signals import password_changed, password_hash_changed
|
||||
from authentik.events.models import Event
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
@@ -40,99 +33,3 @@ class TestUsers(TestCase):
|
||||
self.assertEqual(Event.objects.count(), 1)
|
||||
user.ak_groups.all()
|
||||
self.assertEqual(Event.objects.count(), 1)
|
||||
|
||||
def test_set_password_from_hash_signal_skips_source_sync_receivers(self):
|
||||
"""Test hash password updates do not expose a raw password to sync receivers."""
|
||||
user = User.objects.create(
|
||||
username=generate_id(),
|
||||
attributes={"distinguishedName": "cn=test,ou=users,dc=example,dc=com"},
|
||||
)
|
||||
password_changed_captured = []
|
||||
password_hash_changed_captured = []
|
||||
dispatch_uid = generate_id()
|
||||
hash_dispatch_uid = generate_id()
|
||||
|
||||
def password_changed_receiver(sender, **kwargs):
|
||||
password_changed_captured.append(kwargs)
|
||||
|
||||
def password_hash_changed_receiver(sender, **kwargs):
|
||||
password_hash_changed_captured.append(kwargs)
|
||||
|
||||
password_changed.connect(password_changed_receiver, dispatch_uid=dispatch_uid)
|
||||
password_hash_changed.connect(
|
||||
password_hash_changed_receiver, dispatch_uid=hash_dispatch_uid
|
||||
)
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"authentik.sources.ldap.signals.LDAPSource.objects.filter"
|
||||
) as ldap_sources_filter,
|
||||
patch(
|
||||
"authentik.sources.kerberos.signals."
|
||||
"UserKerberosSourceConnection.objects.select_related"
|
||||
) as kerberos_connections_select,
|
||||
):
|
||||
user.set_password_from_hash(make_password("new-password")) # nosec
|
||||
user.save()
|
||||
finally:
|
||||
password_changed.disconnect(dispatch_uid=dispatch_uid)
|
||||
password_hash_changed.disconnect(dispatch_uid=hash_dispatch_uid)
|
||||
|
||||
self.assertEqual(password_changed_captured, [])
|
||||
self.assertEqual(len(password_hash_changed_captured), 1)
|
||||
ldap_sources_filter.assert_not_called()
|
||||
kerberos_connections_select.assert_not_called()
|
||||
|
||||
|
||||
class TestUserSerializerPasswordHash(TestCase):
|
||||
"""Test UserSerializer password_hash support in blueprint context."""
|
||||
|
||||
def test_password_hash_sets_password_directly(self):
|
||||
"""Test a valid password hash is stored without re-hashing."""
|
||||
password = "test-password-123" # nosec
|
||||
password_hash = make_password(password)
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": password_hash,
|
||||
},
|
||||
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
user = serializer.save()
|
||||
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
self.assertIsNotNone(user.password_change_date)
|
||||
|
||||
def test_password_hash_rejects_invalid_format(self):
|
||||
"""Test invalid password hash values are rejected."""
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": "not-a-valid-hash",
|
||||
},
|
||||
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
serializer.save()
|
||||
|
||||
self.assertIn("Invalid password hash format", str(ctx.exception))
|
||||
|
||||
def test_password_hash_ignored_outside_blueprint_context(self):
|
||||
"""Test password_hash is not accepted by the regular serializer."""
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": make_password("test"), # nosec
|
||||
}
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
self.assertNotIn("password_hash", serializer.validated_data)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from datetime import datetime, timedelta
|
||||
from json import loads
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls.base import reverse
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.test import APITestCase
|
||||
@@ -27,9 +26,6 @@ from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignatio
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.stages.email.models import EmailStage
|
||||
|
||||
INVALID_PASSWORD_HASH = "not-a-valid-hash"
|
||||
INVALID_PASSWORD_HASH_ERROR = "Invalid password hash format. Must be a valid Django password hash."
|
||||
|
||||
|
||||
class TestUsersAPI(APITestCase):
|
||||
"""Test Users API"""
|
||||
@@ -38,20 +34,6 @@ class TestUsersAPI(APITestCase):
|
||||
self.admin = create_test_admin_user()
|
||||
self.user = create_test_user()
|
||||
|
||||
def _set_password_hash(self, user: User, password_hash: str, client=None):
|
||||
return (client or self.client).post(
|
||||
reverse("authentik_api:user-set-password-hash", kwargs={"pk": user.pk}),
|
||||
data={"password": password_hash},
|
||||
)
|
||||
|
||||
def _assert_password_hash_set(
|
||||
self, user: User, password: str, password_hash: str, response
|
||||
) -> None:
|
||||
self.assertEqual(response.status_code, 204, response.data)
|
||||
user.refresh_from_db()
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
|
||||
def test_filter_type(self):
|
||||
"""Test API filtering by type"""
|
||||
self.client.force_login(self.admin)
|
||||
@@ -131,26 +113,6 @@ class TestUsersAPI(APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(response.content, {"password": ["This field may not be blank."]})
|
||||
|
||||
def test_set_password_hash(self):
|
||||
"""Test setting a user's password from a hash."""
|
||||
self.client.force_login(self.admin)
|
||||
password = generate_key()
|
||||
password_hash = make_password(password)
|
||||
response = self._set_password_hash(self.user, password_hash)
|
||||
|
||||
self._assert_password_hash_set(self.user, password, password_hash, response)
|
||||
|
||||
def test_set_password_hash_invalid(self):
|
||||
"""Test invalid password hashes are rejected."""
|
||||
self.client.force_login(self.admin)
|
||||
response = self._set_password_hash(self.user, INVALID_PASSWORD_HASH)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content,
|
||||
{"password": [INVALID_PASSWORD_HASH_ERROR]},
|
||||
)
|
||||
|
||||
def test_recovery(self):
|
||||
"""Test user recovery link"""
|
||||
flow = create_test_flow(
|
||||
@@ -299,29 +261,6 @@ class TestUsersAPI(APITestCase):
|
||||
self.assertTrue(token_filter.exists())
|
||||
self.assertTrue(token_filter.first().expiring)
|
||||
|
||||
def test_service_account_set_password_hash(self):
|
||||
"""Service account password hash can be set through the API."""
|
||||
self.client.force_login(self.admin)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-service-account"),
|
||||
data={
|
||||
"name": "test-sa",
|
||||
"create_group": False,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.data)
|
||||
body = loads(response.content)
|
||||
|
||||
user = User.objects.get(pk=body["user_pk"])
|
||||
self.assertEqual(user.type, UserTypes.SERVICE_ACCOUNT)
|
||||
self.assertFalse(user.has_usable_password())
|
||||
|
||||
password = generate_key()
|
||||
password_hash = make_password(password)
|
||||
response = self._set_password_hash(user, password_hash)
|
||||
|
||||
self._assert_password_hash_set(user, password, password_hash, response)
|
||||
|
||||
def test_service_account_no_expire(self):
|
||||
"""Service account creation without token expiration"""
|
||||
self.client.force_login(self.admin)
|
||||
|
||||
@@ -12,7 +12,7 @@ from authentik.core.models import (
|
||||
User,
|
||||
UserTypes,
|
||||
)
|
||||
from authentik.core.signals import password_changed, password_hash_changed
|
||||
from authentik.core.signals import password_changed
|
||||
from authentik.enterprise.providers.ssf.models import (
|
||||
EventTypes,
|
||||
SSFProvider,
|
||||
@@ -84,13 +84,14 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
|
||||
)
|
||||
|
||||
|
||||
def _send_password_credential_change(user: User, change_type: str):
|
||||
@receiver(password_changed)
|
||||
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
|
||||
"""Credential change trigger (password changed)"""
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
{
|
||||
"credential_type": "password",
|
||||
"change_type": change_type,
|
||||
"change_type": "revoke" if password is None else "update",
|
||||
},
|
||||
sub_id={
|
||||
"format": "complex",
|
||||
@@ -102,16 +103,6 @@ def _send_password_credential_change(user: User, change_type: str):
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_hash_changed)
|
||||
@receiver(password_changed)
|
||||
def ssf_password_changed_cred_change(signal, sender, user: User, password: str | None = None, **_):
|
||||
"""Credential change trigger (password changed)"""
|
||||
if signal is password_hash_changed:
|
||||
_send_password_credential_change(user, "update")
|
||||
return
|
||||
_send_password_credential_change(user, "revoke" if password is None else "update")
|
||||
|
||||
|
||||
device_type_map = {
|
||||
StaticDevice: "pin",
|
||||
TOTPDevice: "pin",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@@ -53,21 +52,6 @@ class TestSignals(APITestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 201, res.content)
|
||||
|
||||
def _assert_password_credential_change(self, user, change_type: str):
|
||||
stream = Stream.objects.filter(provider=self.provider).first()
|
||||
self.assertIsNotNone(stream)
|
||||
event = StreamEvent.objects.filter(stream=stream).first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
event_payload = event.payload["events"][
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
|
||||
]
|
||||
self.assertEqual(event_payload["change_type"], change_type)
|
||||
self.assertEqual(event_payload["credential_type"], "password")
|
||||
self.assertEqual(event.payload["sub_id"]["format"], "complex")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
|
||||
|
||||
def test_signal_logout(self):
|
||||
"""Test user logout"""
|
||||
user = create_test_user()
|
||||
@@ -95,25 +79,19 @@ class TestSignals(APITestCase):
|
||||
user.set_password(generate_id())
|
||||
user.save()
|
||||
|
||||
self._assert_password_credential_change(user, "update")
|
||||
|
||||
def test_signal_password_change_from_hash(self):
|
||||
"""Test user password change from a pre-hashed password."""
|
||||
user = create_test_user()
|
||||
self.client.force_login(user)
|
||||
user.set_password_from_hash(make_password(generate_id()))
|
||||
user.save()
|
||||
|
||||
self._assert_password_credential_change(user, "update")
|
||||
|
||||
def test_signal_password_revoke(self):
|
||||
"""Test explicit password revoke."""
|
||||
user = create_test_user()
|
||||
self.client.force_login(user)
|
||||
user.set_password(None)
|
||||
user.save()
|
||||
|
||||
self._assert_password_credential_change(user, "revoke")
|
||||
stream = Stream.objects.filter(provider=self.provider).first()
|
||||
self.assertIsNotNone(stream)
|
||||
event = StreamEvent.objects.filter(stream=stream).first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
event_payload = event.payload["events"][
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
|
||||
]
|
||||
self.assertEqual(event_payload["change_type"], "update")
|
||||
self.assertEqual(event_payload["credential_type"], "password")
|
||||
self.assertEqual(event.payload["sub_id"]["format"], "complex")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
|
||||
|
||||
def test_signal_authenticator_added(self):
|
||||
"""Test authenticator creation signal"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from django.http import HttpRequest
|
||||
from rest_framework.request import Request
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.core.signals import login_failed, password_changed, password_hash_changed
|
||||
from authentik.core.signals import login_failed, password_changed
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.flows.planner import (
|
||||
@@ -112,15 +112,8 @@ def on_invitation_used(sender, request: HttpRequest, invitation: Invitation, **_
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_hash_changed)
|
||||
@receiver(password_changed)
|
||||
def on_password_changed(
|
||||
sender,
|
||||
user: User,
|
||||
password: str | None = None,
|
||||
request: HttpRequest | None = None,
|
||||
**_,
|
||||
):
|
||||
def on_password_changed(sender, user: User, password: str, request: HttpRequest | None, **_):
|
||||
"""Log password change"""
|
||||
Event.new(EventAction.PASSWORD_SET).from_http(request, user=user)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.views.debug import SafeExceptionReporterFilter
|
||||
@@ -11,7 +10,7 @@ from guardian.shortcuts import get_anonymous_user
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.models import Event
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_id
|
||||
@@ -214,14 +213,3 @@ class TestEvents(TestCase):
|
||||
event = Event.new("unittest", foo="foo bar \u0000 baz")
|
||||
event.save()
|
||||
self.assertEqual(event.context["foo"], "foo bar baz")
|
||||
|
||||
def test_password_set_signal_on_set_password_from_hash(self):
|
||||
"""Changing password from hash should still emit an audit event."""
|
||||
user = create_test_user()
|
||||
old_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
|
||||
|
||||
user.set_password_from_hash(make_password(generate_id()))
|
||||
user.save()
|
||||
|
||||
new_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
|
||||
self.assertEqual(new_count, old_count + 1)
|
||||
|
||||
@@ -65,7 +65,6 @@ class OAuth2ProviderSerializer(ProviderSerializer):
|
||||
fields = ProviderSerializer.Meta.fields + [
|
||||
"authorization_flow",
|
||||
"client_type",
|
||||
"grant_types",
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"access_code_validity",
|
||||
|
||||
@@ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.views import bad_request_message
|
||||
from authentik.providers.oauth2.models import GrantType, RedirectURI
|
||||
from authentik.providers.oauth2.models import GrantTypes, RedirectURI
|
||||
|
||||
|
||||
class OAuth2Error(SentryIgnoredException):
|
||||
@@ -182,7 +182,7 @@ class AuthorizeError(OAuth2Error):
|
||||
# See:
|
||||
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
|
||||
fragment_or_query = (
|
||||
"#" if self.grant_type in [GrantType.IMPLICIT, GrantType.HYBRID] else "?"
|
||||
"#" if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID] else "?"
|
||||
)
|
||||
|
||||
uri = (
|
||||
@@ -225,7 +225,7 @@ class TokenError(OAuth2Error):
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, error: str):
|
||||
def __init__(self, error):
|
||||
super().__init__()
|
||||
self.error = error
|
||||
self.description = self.errors[error]
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-17 11:04
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def migrate_default_grant_types():
|
||||
from authentik.providers.oauth2.models import GrantType
|
||||
|
||||
return [
|
||||
GrantType.AUTHORIZATION_CODE,
|
||||
GrantType.HYBRID,
|
||||
GrantType.IMPLICIT,
|
||||
GrantType.CLIENT_CREDENTIALS,
|
||||
GrantType.PASSWORD,
|
||||
GrantType.DEVICE_CODE,
|
||||
GrantType.REFRESH_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_providers_oauth2",
|
||||
"0031_remove_oauth2provider_backchannel_logout_uri_and_more",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="oauth2provider",
|
||||
name="grant_types",
|
||||
field=django.contrib.postgres.fields.ArrayField(
|
||||
base_field=models.TextField(
|
||||
choices=[
|
||||
("authorization_code", "Authorization Code"),
|
||||
("implicit", "Implicit"),
|
||||
("hybrid", "Hybrid"),
|
||||
("refresh_token", "Refresh Token"),
|
||||
("client_credentials", "Client Credentials"),
|
||||
("password", "Password"),
|
||||
("urn:ietf:params:oauth:grant-type:device_code", "Device Code"),
|
||||
]
|
||||
),
|
||||
default=migrate_default_grant_types,
|
||||
size=None,
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="oauth2provider",
|
||||
name="grant_types",
|
||||
field=django.contrib.postgres.fields.ArrayField(
|
||||
base_field=models.TextField(
|
||||
choices=[
|
||||
("authorization_code", "Authorization Code"),
|
||||
("implicit", "Implicit"),
|
||||
("hybrid", "Hybrid"),
|
||||
("refresh_token", "Refresh Token"),
|
||||
("client_credentials", "Client Credentials"),
|
||||
("password", "Password"),
|
||||
("urn:ietf:params:oauth:grant-type:device_code", "Device Code"),
|
||||
]
|
||||
),
|
||||
default=list,
|
||||
size=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -19,7 +19,6 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
|
||||
from dacite import Config
|
||||
from dacite.core import from_dict
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import HashIndex
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
@@ -34,16 +33,7 @@ from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.brands.models import WebfingerProvider
|
||||
from authentik.common.oauth.constants import (
|
||||
GRANT_TYPE_AUTHORIZATION_CODE,
|
||||
GRANT_TYPE_CLIENT_CREDENTIALS,
|
||||
GRANT_TYPE_DEVICE_CODE,
|
||||
GRANT_TYPE_HYBRID,
|
||||
GRANT_TYPE_IMPLICIT,
|
||||
GRANT_TYPE_PASSWORD,
|
||||
GRANT_TYPE_REFRESH_TOKEN,
|
||||
SubModes,
|
||||
)
|
||||
from authentik.common.oauth.constants import SubModes
|
||||
from authentik.core.models import (
|
||||
AuthenticatedSession,
|
||||
ExpiringModel,
|
||||
@@ -68,7 +58,7 @@ def generate_client_secret() -> str:
|
||||
return generate_id(128)
|
||||
|
||||
|
||||
class ClientType(models.TextChoices):
|
||||
class ClientTypes(models.TextChoices):
|
||||
"""Confidential clients are capable of maintaining the confidentiality
|
||||
of their credentials. Public clients are incapable."""
|
||||
|
||||
@@ -76,16 +66,12 @@ class ClientType(models.TextChoices):
|
||||
PUBLIC = "public", _("Public")
|
||||
|
||||
|
||||
class GrantType(models.TextChoices):
|
||||
class GrantTypes(models.TextChoices):
|
||||
"""OAuth2 Grant types we support"""
|
||||
|
||||
AUTHORIZATION_CODE = GRANT_TYPE_AUTHORIZATION_CODE
|
||||
IMPLICIT = GRANT_TYPE_IMPLICIT
|
||||
HYBRID = GRANT_TYPE_HYBRID
|
||||
REFRESH_TOKEN = GRANT_TYPE_REFRESH_TOKEN
|
||||
CLIENT_CREDENTIALS = GRANT_TYPE_CLIENT_CREDENTIALS
|
||||
PASSWORD = GRANT_TYPE_PASSWORD
|
||||
DEVICE_CODE = GRANT_TYPE_DEVICE_CODE
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
IMPLICIT = "implicit"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
class ResponseMode(models.TextChoices):
|
||||
@@ -202,15 +188,14 @@ class OAuth2Provider(WebfingerProvider, Provider):
|
||||
|
||||
client_type = models.CharField(
|
||||
max_length=30,
|
||||
choices=ClientType.choices,
|
||||
default=ClientType.CONFIDENTIAL,
|
||||
choices=ClientTypes.choices,
|
||||
default=ClientTypes.CONFIDENTIAL,
|
||||
verbose_name=_("Client Type"),
|
||||
help_text=_(
|
||||
"Confidential clients are capable of maintaining the confidentiality "
|
||||
"of their credentials. Public clients are incapable"
|
||||
),
|
||||
)
|
||||
grant_types = ArrayField(models.TextField(choices=GrantType.choices), default=list)
|
||||
client_id = models.CharField(
|
||||
max_length=255,
|
||||
unique=True,
|
||||
|
||||
@@ -22,7 +22,7 @@ from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, Red
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
GrantType,
|
||||
GrantTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -41,34 +41,12 @@ class TestAuthorize(OAuthTestCase):
|
||||
super().setUp()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_disallowed_grant_type(self):
|
||||
"""Test with disallowed grant type"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
grant_types=[],
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://local.invalid/Foo",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "invalid_request")
|
||||
|
||||
def test_invalid_grant_type(self):
|
||||
"""Test with invalid grant type"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
@@ -96,7 +74,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
@@ -211,7 +188,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
@@ -230,7 +206,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
@@ -252,14 +227,12 @@ class TestAuthorize(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type,
|
||||
GrantType.AUTHORIZATION_CODE,
|
||||
GrantTypes.AUTHORIZATION_CODE,
|
||||
)
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).redirect_uri,
|
||||
"http://local.invalid/Foo",
|
||||
)
|
||||
provider.grant_types = [GrantType.IMPLICIT]
|
||||
provider.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -273,7 +246,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type,
|
||||
GrantType.IMPLICIT,
|
||||
GrantTypes.IMPLICIT,
|
||||
)
|
||||
# Implicit without openid scope
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
@@ -288,10 +261,8 @@ class TestAuthorize(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type,
|
||||
GrantType.IMPLICIT,
|
||||
GrantTypes.IMPLICIT,
|
||||
)
|
||||
provider.grant_types = [GrantType.HYBRID]
|
||||
provider.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -303,7 +274,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantType.HYBRID
|
||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
@@ -326,7 +297,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -363,7 +333,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
signing_key=self.keypair,
|
||||
grant_types=[GrantType.IMPLICIT],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
@@ -435,7 +404,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
signing_key=self.keypair,
|
||||
encryption_key=self.keypair,
|
||||
grant_types=[GrantType.IMPLICIT],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
@@ -498,7 +466,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
signing_key=self.keypair,
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -548,7 +515,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
signing_key=self.keypair,
|
||||
grant_types=[GrantType.IMPLICIT],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
@@ -606,7 +572,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
signing_key=self.keypair,
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
|
||||
state = generate_id()
|
||||
@@ -647,7 +612,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.IMPLICIT],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
request = self.factory.get(
|
||||
@@ -671,7 +635,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
grant_types=[GrantType.IMPLICIT],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
@@ -704,7 +667,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -735,7 +697,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -775,7 +736,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authentication_flow=auth_flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -802,7 +762,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
||||
@@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.DEVICE_CODE],
|
||||
)
|
||||
self.application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
@@ -43,21 +42,6 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel_invalid_no_grant(self):
|
||||
"""Test backchannel"""
|
||||
self.provider.grant_types = []
|
||||
self.provider.save()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
data={
|
||||
"client_id": "test",
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel_invalid_no_app(self):
|
||||
"""Test backchannel"""
|
||||
# test without application
|
||||
self.application.provider = None
|
||||
self.application.save()
|
||||
|
||||
@@ -9,7 +9,7 @@ from authentik.core.models import Application, Group
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
@@ -22,7 +22,6 @@ class TesOAuth2DeviceInit(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.DEVICE_CODE],
|
||||
)
|
||||
self.application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
|
||||
@@ -14,7 +14,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
ClientType,
|
||||
ClientTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -173,7 +173,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
|
||||
def test_introspect_provider_public(self):
|
||||
"""Test introspect"""
|
||||
self.provider.client_type = ClientType.PUBLIC
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
@@ -208,7 +208,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientType.PUBLIC,
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
ClientType,
|
||||
ClientTypes,
|
||||
DeviceToken,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
@@ -126,7 +126,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
|
||||
def test_revoke_public(self):
|
||||
"""Test revoke public client"""
|
||||
self.provider.client_type = ClientType.PUBLIC
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
@@ -241,7 +241,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientType.PUBLIC,
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
@@ -270,14 +270,14 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
def test_revoke_provider_fed_public(self):
|
||||
"""Test revoke with federation. self.provider is a public
|
||||
client and other_provider is a public client."""
|
||||
self.provider.client_type = ClientType.PUBLIC
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientType.PUBLIC,
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from authentik.providers.oauth2.errors import TokenError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -45,39 +44,11 @@ class TestToken(OAuthTestCase):
|
||||
self.factory = RequestFactory()
|
||||
self.app = Application.objects.create(name=generate_id(), slug="test")
|
||||
|
||||
def test_invalid_grant_type(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
code = AuthorizationCode.objects.create(
|
||||
code="foobar", provider=provider, user=user, auth_time=timezone.now()
|
||||
)
|
||||
request = self.factory.post(
|
||||
"/",
|
||||
data={
|
||||
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
|
||||
"code": code.code,
|
||||
"redirect_uri": "http://TestServer",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
with self.assertRaises(TokenError) as cm:
|
||||
TokenParams.parse(request, provider, provider.client_id, provider.client_secret)
|
||||
self.assertEqual(cm.exception.cause, "grant_type_not_configured")
|
||||
|
||||
def test_request_auth_code(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -105,7 +76,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -127,7 +97,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -170,7 +139,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -211,7 +179,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
encryption_key=self.keypair,
|
||||
@@ -243,7 +210,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -305,7 +271,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -363,7 +328,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
@@ -436,7 +400,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.REFRESH_TOKEN],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
refresh_token_threshold="hours=1", # nosec
|
||||
@@ -534,7 +497,6 @@ class TestToken(OAuthTestCase):
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
include_claims_in_id_token=True,
|
||||
|
||||
@@ -22,7 +22,6 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -56,7 +55,6 @@ class TestTokenClientCredentialsJWTProvider(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=self.cert,
|
||||
grant_types=[GrantType.CLIENT_CREDENTIALS],
|
||||
)
|
||||
self.provider.jwt_federation_providers.add(self.other_provider)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
|
||||
@@ -20,7 +20,6 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import (
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -69,7 +68,6 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=self.cert,
|
||||
grant_types=[GrantType.CLIENT_CREDENTIALS],
|
||||
)
|
||||
self.provider.jwt_federation_sources.add(self.source)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
|
||||
@@ -21,7 +21,6 @@ from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.errors import TokenError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -42,7 +41,6 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=create_test_cert(),
|
||||
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
|
||||
)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
|
||||
|
||||
@@ -22,7 +22,6 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_cert,
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.errors import TokenError
|
||||
from authentik.providers.oauth2.models import (
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -43,7 +42,6 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=create_test_cert(),
|
||||
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
|
||||
)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
|
||||
|
||||
@@ -25,7 +25,6 @@ from authentik.core.tests.utils import (
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.errors import TokenError
|
||||
from authentik.providers.oauth2.models import (
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -46,7 +45,6 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=create_test_cert(),
|
||||
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
|
||||
)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
|
||||
|
||||
@@ -17,7 +17,6 @@ from authentik.lib.generators import generate_code_fixed_length, generate_id
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
DeviceToken,
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -38,7 +37,6 @@ class TestTokenDeviceCode(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
|
||||
signing_key=create_test_cert(),
|
||||
grant_types=[GrantType.DEVICE_CODE],
|
||||
)
|
||||
self.provider.property_mappings.set(ScopeMapping.objects.all())
|
||||
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
|
||||
|
||||
@@ -11,7 +11,6 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import (
|
||||
AuthorizationCode,
|
||||
GrantType,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -38,7 +37,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -97,7 +95,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -154,7 +151,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
@@ -200,7 +196,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
authorization_flow=flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
grant_types=[GrantType.AUTHORIZATION_CODE],
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
||||
@@ -57,7 +57,7 @@ from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
GrantType,
|
||||
GrantTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURIMatchingMode,
|
||||
ResponseMode,
|
||||
@@ -164,31 +164,28 @@ class OAuthAuthorizationParams:
|
||||
"""Check grant"""
|
||||
# Determine which flow to use.
|
||||
if self.response_type in [ResponseTypes.CODE]:
|
||||
self.grant_type = GrantType.AUTHORIZATION_CODE
|
||||
self.grant_type = GrantTypes.AUTHORIZATION_CODE
|
||||
elif self.response_type in [
|
||||
ResponseTypes.ID_TOKEN,
|
||||
ResponseTypes.ID_TOKEN_TOKEN,
|
||||
]:
|
||||
self.grant_type = GrantType.IMPLICIT
|
||||
self.grant_type = GrantTypes.IMPLICIT
|
||||
elif self.response_type in [
|
||||
ResponseTypes.CODE_TOKEN,
|
||||
ResponseTypes.CODE_ID_TOKEN,
|
||||
ResponseTypes.CODE_ID_TOKEN_TOKEN,
|
||||
]:
|
||||
self.grant_type = GrantType.HYBRID
|
||||
self.grant_type = GrantTypes.HYBRID
|
||||
|
||||
# Grant type validation.
|
||||
if not self.grant_type:
|
||||
LOGGER.warning("Invalid response type", type=self.response_type)
|
||||
raise AuthorizeError(self.redirect_uri, "unsupported_response_type", "", self.state)
|
||||
|
||||
if self.grant_type not in self.provider.grant_types:
|
||||
LOGGER.warning("Invalid grant_type for provider", grant_type=self.grant_type)
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
||||
|
||||
if self.response_mode not in ResponseMode.values:
|
||||
self.response_mode = ResponseMode.QUERY
|
||||
|
||||
if self.grant_type in [GrantType.IMPLICIT, GrantType.HYBRID]:
|
||||
if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]:
|
||||
self.response_mode = ResponseMode.FRAGMENT
|
||||
|
||||
def check_redirect_uri(self):
|
||||
@@ -249,7 +246,7 @@ class OAuthAuthorizationParams:
|
||||
)
|
||||
self.scope = self.scope.intersection(default_scope_names)
|
||||
if SCOPE_OPENID not in self.scope and (
|
||||
self.grant_type == GrantType.HYBRID
|
||||
self.grant_type == GrantTypes.HYBRID
|
||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||
):
|
||||
LOGGER.warning("Missing 'openid' scope.")
|
||||
@@ -600,8 +597,8 @@ class OAuthFulfillmentStage(StageView):
|
||||
code = None
|
||||
|
||||
if self.params.grant_type in [
|
||||
GrantType.AUTHORIZATION_CODE,
|
||||
GrantType.HYBRID,
|
||||
GrantTypes.AUTHORIZATION_CODE,
|
||||
GrantTypes.HYBRID,
|
||||
]:
|
||||
code = self.params.create_code(self.request)
|
||||
code.save()
|
||||
@@ -616,7 +613,7 @@ class OAuthFulfillmentStage(StageView):
|
||||
|
||||
if self.params.response_mode == ResponseMode.FRAGMENT:
|
||||
query_fragment = {}
|
||||
if self.params.grant_type in [GrantType.AUTHORIZATION_CODE]:
|
||||
if self.params.grant_type in [GrantTypes.AUTHORIZATION_CODE]:
|
||||
query_fragment["code"] = code.code
|
||||
query_fragment["state"] = [str(self.params.state) if self.params.state else ""]
|
||||
else:
|
||||
@@ -630,7 +627,7 @@ class OAuthFulfillmentStage(StageView):
|
||||
|
||||
if self.params.response_mode == ResponseMode.FORM_POST:
|
||||
post_params = {}
|
||||
if self.params.grant_type in [GrantType.AUTHORIZATION_CODE]:
|
||||
if self.params.grant_type in [GrantTypes.AUTHORIZATION_CODE]:
|
||||
post_params["code"] = code.code
|
||||
post_params["state"] = [str(self.params.state) if self.params.state else ""]
|
||||
else:
|
||||
@@ -699,7 +696,7 @@ class OAuthFulfillmentStage(StageView):
|
||||
token.save()
|
||||
|
||||
# Code parameter must be present if it's Hybrid Flow.
|
||||
if self.params.grant_type == GrantType.HYBRID:
|
||||
if self.params.grant_type == GrantTypes.HYBRID:
|
||||
query_fragment["code"] = code.code
|
||||
|
||||
query_fragment["token_type"] = TOKEN_TYPE
|
||||
|
||||
@@ -15,7 +15,7 @@ from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
@@ -42,8 +42,6 @@ class DeviceView(View):
|
||||
_ = provider.application
|
||||
except Application.DoesNotExist:
|
||||
raise DeviceCodeError("invalid_client") from None
|
||||
if GrantType.DEVICE_CODE not in provider.grant_types:
|
||||
raise DeviceCodeError("invalid_client")
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.providers.oauth2.errors import TokenIntrospectionError
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientType, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.utils import TokenResponse, authenticate_provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -45,7 +45,7 @@ class TokenIntrospectionParams:
|
||||
if not provider:
|
||||
LOGGER.info("Failed to authenticate introspection request")
|
||||
raise TokenIntrospectionError
|
||||
if provider.client_type != ClientType.CONFIDENTIAL:
|
||||
if provider.client_type != ClientTypes.CONFIDENTIAL:
|
||||
LOGGER.info("Introspection request from public provider, denying.")
|
||||
raise TokenIntrospectionError
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
ClientType,
|
||||
ClientTypes,
|
||||
DeviceToken,
|
||||
OAuth2Provider,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -165,10 +165,6 @@ class TokenParams:
|
||||
raise TokenError("invalid_grant")
|
||||
|
||||
def __post_init__(self, raw_code: str, raw_token: str, request: HttpRequest):
|
||||
if self.grant_type not in self.provider.grant_types:
|
||||
LOGGER.warning("Invalid grant_type for provider", grant_type=self.grant_type)
|
||||
raise TokenError("invalid_grant").with_cause("grant_type_not_configured")
|
||||
|
||||
# Confidential clients MUST authenticate to the token endpoint per
|
||||
# RFC 6749 §2.3.1. The device code grant (RFC 8628 §3.4) inherits
|
||||
# that requirement - the device_code alone is not a substitute for
|
||||
@@ -178,7 +174,7 @@ class TokenParams:
|
||||
GRANT_TYPE_REFRESH_TOKEN,
|
||||
GRANT_TYPE_DEVICE_CODE,
|
||||
]:
|
||||
if self.provider.client_type == ClientType.CONFIDENTIAL and not compare_digest(
|
||||
if self.provider.client_type == ClientTypes.CONFIDENTIAL and not compare_digest(
|
||||
self.provider.client_secret, self.client_secret
|
||||
):
|
||||
LOGGER.warning(
|
||||
@@ -610,10 +606,10 @@ class TokenView(View):
|
||||
if not self.provider:
|
||||
LOGGER.warning("OAuth2Provider does not exist", client_id=client_id)
|
||||
raise TokenError("invalid_client")
|
||||
CTX_AUTH_VIA.set("oauth_client_secret")
|
||||
self.params = self.params_class.parse(
|
||||
request, self.provider, client_id, client_secret
|
||||
)
|
||||
CTX_AUTH_VIA.set("oauth_client_secret")
|
||||
|
||||
with start_span(
|
||||
op="authentik.providers.oauth2.post.response",
|
||||
|
||||
@@ -10,7 +10,7 @@ from django.views.decorators.csrf import csrf_exempt
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.providers.oauth2.errors import TokenRevocationError
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientType, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.utils import (
|
||||
TokenResponse,
|
||||
authenticate_provider,
|
||||
@@ -33,13 +33,11 @@ class TokenRevocationParams:
|
||||
raw_token = request.POST.get("token")
|
||||
|
||||
provider, _, _ = provider_from_request(request)
|
||||
if provider and provider.client_type == ClientType.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
# By default clients can only revoke their own tokens
|
||||
query = Q(provider=provider, token=raw_token)
|
||||
if provider.client_type == ClientType.CONFIDENTIAL:
|
||||
if provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
|
||||
@@ -16,8 +16,7 @@ from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.models import DomainlessURLValidator, InternallyManagedMixin
|
||||
from authentik.outposts.models import OutpostModel
|
||||
from authentik.providers.oauth2.models import (
|
||||
ClientType,
|
||||
GrantType,
|
||||
ClientTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -162,12 +161,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider):
|
||||
|
||||
def set_oauth_defaults(self):
|
||||
"""Ensure all OAuth2-related settings are correct"""
|
||||
self.grant_types = [
|
||||
GrantType.AUTHORIZATION_CODE,
|
||||
GrantType.CLIENT_CREDENTIALS,
|
||||
GrantType.PASSWORD,
|
||||
]
|
||||
self.client_type = ClientType.CONFIDENTIAL
|
||||
self.client_type = ClientTypes.CONFIDENTIAL
|
||||
self.signing_key = None
|
||||
self.include_claims_in_id_token = True
|
||||
scopes = ScopeMapping.objects.filter(
|
||||
|
||||
@@ -9,7 +9,7 @@ from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.oauth2.models import ClientType
|
||||
from authentik.providers.oauth2.models import ClientTypes
|
||||
from authentik.providers.proxy.models import ProxyMode, ProxyProvider
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class ProxyProviderTests(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
|
||||
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
|
||||
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
|
||||
|
||||
def test_update_defaults(self):
|
||||
"""Test create"""
|
||||
@@ -114,8 +114,8 @@ class ProxyProviderTests(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
|
||||
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
|
||||
provider.client_type = ClientType.PUBLIC
|
||||
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
|
||||
provider.client_type = ClientTypes.PUBLIC
|
||||
provider.save()
|
||||
response = self.client.put(
|
||||
reverse("authentik_api:proxyprovider-detail", kwargs={"pk": provider.pk}),
|
||||
@@ -130,7 +130,7 @@ class ProxyProviderTests(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
|
||||
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
|
||||
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
|
||||
|
||||
def test_sa_fetch(self):
|
||||
"""Test fetching the outpost config as the service account"""
|
||||
|
||||
@@ -24,11 +24,7 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.validation import validate
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
SAML_BINDING_POST,
|
||||
SAML_BINDING_REDIRECT,
|
||||
)
|
||||
from authentik.common.saml.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
|
||||
@@ -59,7 +55,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"""SAMLProvider Serializer"""
|
||||
|
||||
url_download_metadata = SerializerMethodField()
|
||||
url_issuer = SerializerMethodField()
|
||||
|
||||
url_sso_post = SerializerMethodField()
|
||||
url_sso_redirect = SerializerMethodField()
|
||||
@@ -90,23 +85,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
def get_url_issuer(self, instance: SAMLProvider) -> str:
|
||||
"""Get Issuer/EntityID URL"""
|
||||
if instance.issuer_override:
|
||||
return instance.issuer_override
|
||||
if "request" not in self._context:
|
||||
return DEFAULT_ISSUER
|
||||
request: HttpRequest = self._context["request"]._request
|
||||
try:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": instance.application.slug},
|
||||
)
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_url_sso_post(self, instance: SAMLProvider) -> str:
|
||||
"""Get SSO Post URL"""
|
||||
if "request" not in self._context:
|
||||
@@ -220,7 +198,7 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"acs_url",
|
||||
"sls_url",
|
||||
"audience",
|
||||
"issuer_override",
|
||||
"issuer",
|
||||
"assertion_valid_not_before",
|
||||
"assertion_valid_not_on_or_after",
|
||||
"session_valid_not_on_or_after",
|
||||
@@ -242,7 +220,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"default_relay_state",
|
||||
"default_name_id_policy",
|
||||
"url_download_metadata",
|
||||
"url_issuer",
|
||||
"url_sso_post",
|
||||
"url_sso_redirect",
|
||||
"url_sso_init",
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-24 06:03
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_saml", "0021_samlprovider_sign_logout_response"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameField(
|
||||
model_name="samlprovider",
|
||||
old_name="issuer",
|
||||
new_name="issuer_override",
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="samlprovider",
|
||||
name="issuer_override",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
help_text="Also known as EntityID. Providing a value overrides the default issuer generated by authentik.",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlsession",
|
||||
name="issuer",
|
||||
field=models.TextField(
|
||||
default=None, help_text="SAML Issuer used for this session", null=True
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -77,14 +77,7 @@ class SAMLProvider(Provider):
|
||||
"no audience restriction will be added."
|
||||
),
|
||||
)
|
||||
issuer_override = models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
help_text=_(
|
||||
"Also known as EntityID. Providing a value overrides the default issuer "
|
||||
"generated by authentik."
|
||||
),
|
||||
)
|
||||
issuer = models.TextField(help_text=_("Also known as EntityID"), default="authentik")
|
||||
sls_url = models.TextField(
|
||||
blank=True,
|
||||
validators=[DomainlessURLValidator(schemes=("http", "https"))],
|
||||
@@ -325,9 +318,6 @@ class SAMLSession(InternallyManagedMixin, SerializerModel, ExpiringModel):
|
||||
session_index = models.TextField(help_text=_("SAML SessionIndex for this session"))
|
||||
name_id = models.TextField(help_text=_("SAML NameID value for this session"))
|
||||
name_id_format = models.TextField(default="", blank=True, help_text=_("SAML NameID format"))
|
||||
issuer = models.TextField(
|
||||
default=None, null=True, help_text=_("SAML Issuer used for this session")
|
||||
)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,7 +6,6 @@ from types import GeneratorType
|
||||
|
||||
import xmlsec
|
||||
from django.http import HttpRequest
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from lxml import etree # nosec
|
||||
from lxml.etree import Element, SubElement, _Element # nosec
|
||||
@@ -64,7 +63,6 @@ class AssertionProcessor:
|
||||
session_index: str
|
||||
name_id: str
|
||||
name_id_format: str
|
||||
issuer: str
|
||||
session_not_on_or_after_datetime: datetime
|
||||
|
||||
def __init__(self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest):
|
||||
@@ -139,24 +137,10 @@ class AssertionProcessor:
|
||||
continue
|
||||
return attribute_statement
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value, with fallback to generated URL if empty"""
|
||||
# If user has set an override issuer, use it
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
|
||||
return self.http_request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": self.provider.application.slug},
|
||||
)
|
||||
)
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer Element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer", nsmap=NS_MAP)
|
||||
self.issuer = self._get_issuer_value()
|
||||
issuer.text = self.issuer
|
||||
issuer.text = self.provider.issuer
|
||||
return issuer
|
||||
|
||||
def get_assertion_auth_n_statement(self) -> Element:
|
||||
|
||||
@@ -8,7 +8,6 @@ from lxml import etree # nosec
|
||||
from lxml.etree import Element, _Element
|
||||
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
NS_SAML_ASSERTION,
|
||||
@@ -34,12 +33,11 @@ class LogoutRequestProcessor:
|
||||
name_id_format: str
|
||||
session_index: str | None
|
||||
relay_state: str | None
|
||||
issuer: str | None
|
||||
|
||||
_issue_instant: str
|
||||
_request_id: str
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
def __init__(
|
||||
self,
|
||||
provider: SAMLProvider,
|
||||
user: User | None,
|
||||
@@ -48,7 +46,6 @@ class LogoutRequestProcessor:
|
||||
name_id_format: str = SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.user = user
|
||||
@@ -57,23 +54,14 @@ class LogoutRequestProcessor:
|
||||
self.name_id_format = name_id_format
|
||||
self.session_index = session_index
|
||||
self.relay_state = relay_state
|
||||
self.issuer = issuer
|
||||
|
||||
self._issue_instant = get_time_string()
|
||||
self._request_id = get_random_id()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value from session, with fallback to provider"""
|
||||
if self.issuer:
|
||||
return self.issuer
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
issuer.text = self._get_issuer_value()
|
||||
issuer.text = self.provider.issuer
|
||||
return issuer
|
||||
|
||||
def get_name_id(self) -> Element:
|
||||
|
||||
@@ -8,7 +8,6 @@ from lxml import etree
|
||||
from lxml.etree import Element, SubElement
|
||||
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
NS_SAML_ASSERTION,
|
||||
@@ -29,38 +28,27 @@ class LogoutResponseProcessor:
|
||||
logout_request: LogoutRequest
|
||||
destination: str | None
|
||||
relay_state: str | None
|
||||
issuer: str | None
|
||||
_issue_instant: str
|
||||
_response_id: str
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
def __init__(
|
||||
self,
|
||||
provider: SAMLProvider,
|
||||
logout_request: LogoutRequest,
|
||||
destination: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.logout_request = logout_request
|
||||
self.destination = destination
|
||||
self.relay_state = relay_state or (logout_request.relay_state if logout_request else None)
|
||||
self.issuer = issuer
|
||||
self._issue_instant = get_time_string()
|
||||
self._response_id = get_random_id()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value from session, with fallback to provider"""
|
||||
if self.issuer:
|
||||
return self.issuer
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
issuer.text = self._get_issuer_value()
|
||||
issuer.text = self.provider.issuer
|
||||
return issuer
|
||||
|
||||
def build(self, status: str = "Success") -> Element:
|
||||
|
||||
@@ -40,19 +40,6 @@ class MetadataProcessor:
|
||||
self.force_binding = None
|
||||
self.xml_id = "_" + sha256(f"{provider.name}-{provider.pk}".encode("ascii")).hexdigest()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value, with fallback to generated URL if empty"""
|
||||
# If user has set an override issuer, use it
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
|
||||
return self.http_request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": self.provider.application.slug},
|
||||
)
|
||||
)
|
||||
|
||||
# Using type unions doesn't work with cython types (which is what lxml is)
|
||||
def get_signing_key_descriptor(self) -> Element | None:
|
||||
"""Get Signing KeyDescriptor, if enabled for the provider"""
|
||||
@@ -202,7 +189,7 @@ class MetadataProcessor:
|
||||
"""Build full EntityDescriptor"""
|
||||
entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP)
|
||||
entity_descriptor.attrib["ID"] = self.xml_id
|
||||
entity_descriptor.attrib["entityID"] = self._get_issuer_value()
|
||||
entity_descriptor.attrib["entityID"] = self.provider.issuer
|
||||
|
||||
if self.provider.signing_kp:
|
||||
self._prepare_signature(entity_descriptor)
|
||||
|
||||
@@ -51,6 +51,7 @@ class ServiceProviderMetadata:
|
||||
provider = SAMLProvider.objects.create(
|
||||
name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
|
||||
)
|
||||
provider.issuer = self.entity_id
|
||||
provider.sp_binding = self.acs_binding
|
||||
provider.acs_url = self.acs_location
|
||||
provider.default_name_id_policy = self.name_id_policy
|
||||
|
||||
@@ -75,7 +75,6 @@ def handle_saml_iframe_pre_user_logout(
|
||||
name_id_format=session.name_id_format,
|
||||
session_index=session.session_index,
|
||||
relay_state=relay_state,
|
||||
issuer=session.issuer,
|
||||
)
|
||||
|
||||
if session.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -164,7 +163,6 @@ def handle_flow_pre_user_logout(
|
||||
name_id_format=session.name_id_format,
|
||||
session_index=session.session_index,
|
||||
relay_state=relay_state,
|
||||
issuer=session.issuer,
|
||||
)
|
||||
|
||||
if session.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -226,7 +224,6 @@ def user_session_deleted_saml_logout(sender, instance: AuthenticatedSession, **_
|
||||
name_id=saml_session.name_id,
|
||||
name_id_format=saml_session.name_id_format,
|
||||
session_index=saml_session.session_index,
|
||||
issuer=saml_session.issuer,
|
||||
)
|
||||
|
||||
|
||||
@@ -260,5 +257,4 @@ def user_deactivated_saml_logout(sender, instance: User, **kwargs):
|
||||
name_id=saml_session.name_id,
|
||||
name_id_format=saml_session.name_id_format,
|
||||
session_index=saml_session.session_index,
|
||||
issuer=saml_session.issuer,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ def send_saml_logout_request(
|
||||
name_id: str,
|
||||
name_id_format: str,
|
||||
session_index: str,
|
||||
issuer: str,
|
||||
):
|
||||
"""Send SAML LogoutRequest to a Service Provider using session data"""
|
||||
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
|
||||
@@ -48,7 +47,6 @@ def send_saml_logout_request(
|
||||
name_id=name_id,
|
||||
name_id_format=name_id_format,
|
||||
session_index=session_index,
|
||||
issuer=issuer,
|
||||
)
|
||||
|
||||
return send_post_logout_request(provider, processor)
|
||||
@@ -91,7 +89,6 @@ def send_saml_logout_response(
|
||||
sls_url: str,
|
||||
logout_request_id: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
"""Send SAML LogoutResponse to a Service Provider using backchannel (server-to-server)"""
|
||||
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
|
||||
@@ -122,7 +119,6 @@ def send_saml_logout_response(
|
||||
logout_request=logout_request,
|
||||
destination=sls_url,
|
||||
relay_state=relay_state,
|
||||
issuer=issuer,
|
||||
)
|
||||
|
||||
encoded_response = processor.encode_post()
|
||||
|
||||
@@ -15,7 +15,6 @@ from authentik.common.saml.constants import (
|
||||
SAML_NAME_ID_FORMAT_EMAIL,
|
||||
SAML_NAME_ID_FORMAT_UNSPECIFIED,
|
||||
)
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import (
|
||||
RequestFactory,
|
||||
create_test_admin_user,
|
||||
@@ -98,11 +97,6 @@ class TestAuthNRequest(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app",
|
||||
provider=self.provider,
|
||||
)
|
||||
self.source = SAMLSource.objects.create(
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
@@ -532,7 +526,7 @@ class TestAuthNRequest(TestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
audience="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
signing_kp=static_keypair,
|
||||
verification_kp=static_keypair,
|
||||
)
|
||||
@@ -553,7 +547,7 @@ class TestAuthNRequest(TestCase):
|
||||
"saml/acs/2d737f96-55fb-4035-953e-5e24134eb778"
|
||||
),
|
||||
audience="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
signing_kp=create_test_cert(),
|
||||
)
|
||||
parsed_request = AuthNRequestParser(provider).parse(POST_REQUEST)
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestNativeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp1.example.com/acs",
|
||||
sls_url="https://sp1.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
|
||||
@@ -58,7 +58,7 @@ class TestNativeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="post",
|
||||
sls_binding="post",
|
||||
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
|
||||
@@ -218,7 +218,7 @@ class TestIframeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp1.example.com/acs",
|
||||
sls_url="https://sp1.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
logout_method="frontchannel_iframe",
|
||||
@@ -229,7 +229,7 @@ class TestIframeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="post",
|
||||
sls_binding="post",
|
||||
logout_method="frontchannel_iframe",
|
||||
@@ -372,7 +372,7 @@ class TestIdPLogoutIntegration(FlowTestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.keypair,
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestLogoutIntegration(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signature_algorithm=RSA_SHA256,
|
||||
@@ -57,7 +57,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -72,7 +72,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse_detached(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -106,7 +106,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = parser.parse(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.name_id, "signed@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "signed-session-456")
|
||||
@@ -125,7 +125,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse_detached(saml_request)
|
||||
|
||||
# Verify parsing succeeded
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
|
||||
@@ -164,7 +164,7 @@ class TestLogoutIntegration(TestCase):
|
||||
|
||||
# Parse the SAMLRequest (unsigned XML)
|
||||
parsed = self.parser.parse_detached(params["SAMLRequest"][0])
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
|
||||
def test_form_data_can_be_parsed(self):
|
||||
"""Test that form data generates parseable POST request"""
|
||||
@@ -175,7 +175,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse(form_data["SAMLRequest"])
|
||||
|
||||
# Verify parsing succeeded
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -244,4 +244,4 @@ class TestLogoutIntegration(TestCase):
|
||||
|
||||
# But same issuer
|
||||
self.assertEqual(parsed1.issuer, parsed2.issuer)
|
||||
self.assertEqual(parsed1.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed1.issuer, self.provider.issuer)
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestLogoutRequestProcessor(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signature_algorithm=RSA_SHA256,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""logout response tests"""
|
||||
|
||||
from defusedxml import ElementTree
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.common.saml.constants import (
|
||||
@@ -9,13 +9,10 @@ from authentik.common.saml.constants import (
|
||||
NS_SAML_PROTOCOL,
|
||||
NS_SIGNATURE,
|
||||
)
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||
from authentik.providers.saml.processors.logout_request_parser import LogoutRequest
|
||||
from authentik.providers.saml.processors.logout_response_processor import LogoutResponseProcessor
|
||||
from authentik.providers.saml.processors.metadata import MetadataProcessor
|
||||
|
||||
|
||||
class TestLogoutResponse(TestCase):
|
||||
@@ -24,7 +21,6 @@ class TestLogoutResponse(TestCase):
|
||||
@apply_blueprint("system/providers-saml.yaml")
|
||||
def setUp(self):
|
||||
cert = create_test_cert()
|
||||
self.factory = RequestFactory()
|
||||
self.provider: SAMLProvider = SAMLProvider.objects.create(
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="http://testserver/source/saml/provider/acs/",
|
||||
@@ -34,31 +30,17 @@ class TestLogoutResponse(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
self.application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_build_response(self):
|
||||
"""Test building a LogoutResponse uses the generated issuer from the assertion"""
|
||||
# Generate the issuer the same way the assertion/metadata processors would
|
||||
request = self.factory.get("/")
|
||||
metadata_processor = MetadataProcessor(self.provider, request)
|
||||
generated_issuer = metadata_processor._get_issuer_value()
|
||||
|
||||
"""Test building a LogoutResponse"""
|
||||
logout_request = LogoutRequest(
|
||||
id="test-request-id",
|
||||
issuer="test-sp",
|
||||
relay_state="test-relay-state",
|
||||
)
|
||||
|
||||
# Pass the generated issuer as if it came from SAMLSession.issuer
|
||||
processor = LogoutResponseProcessor(
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=generated_issuer,
|
||||
self.provider, logout_request, destination=self.provider.sls_url
|
||||
)
|
||||
response_xml = processor.build_response(status="Success")
|
||||
|
||||
@@ -69,9 +51,9 @@ class TestLogoutResponse(TestCase):
|
||||
self.assertEqual(root.attrib["Destination"], self.provider.sls_url)
|
||||
self.assertEqual(root.attrib["InResponseTo"], "test-request-id")
|
||||
|
||||
# Check Issuer matches the generated issuer from the assertion processor
|
||||
# Check Issuer
|
||||
issuer = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
self.assertEqual(issuer.text, generated_issuer)
|
||||
self.assertEqual(issuer.text, self.provider.issuer)
|
||||
|
||||
# Check Status
|
||||
status = root.find(f".//{{{NS_SAML_PROTOCOL}}}StatusCode")
|
||||
|
||||
@@ -85,6 +85,7 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml"))
|
||||
provider = metadata.to_provider("test", self.flow, self.flow)
|
||||
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
|
||||
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
|
||||
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
||||
self.assertEqual(provider.default_name_id_policy, SAMLNameIDPolicy.EMAIL)
|
||||
self.assertEqual(
|
||||
@@ -98,6 +99,7 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/cert.xml"))
|
||||
provider = metadata.to_provider("test", self.flow, self.flow)
|
||||
self.assertEqual(provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs")
|
||||
self.assertEqual(provider.issuer, "http://localhost:8080/apps/user_saml/saml/metadata")
|
||||
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
||||
self.assertEqual(
|
||||
provider.verification_kp.certificate_data, load_fixture("fixtures/cert.pem")
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name="test-provider",
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
# Create another provider for testing
|
||||
@@ -40,7 +40,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name="test-provider-2",
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
issuer_override="https://idp2.example.com",
|
||||
issuer="https://idp2.example.com",
|
||||
)
|
||||
|
||||
# Create a session first (using authentik's custom Session model)
|
||||
@@ -72,7 +72,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify the session was created
|
||||
@@ -101,7 +100,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Try to create another session with same session_index and provider
|
||||
@@ -115,7 +113,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
def test_cascade_deletion_user(self):
|
||||
@@ -130,7 +127,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -154,7 +150,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -178,7 +173,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -202,7 +196,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Create second session with different provider
|
||||
@@ -215,7 +208,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify both sessions exist
|
||||
@@ -237,7 +229,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=future_time,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify expiry time
|
||||
@@ -257,7 +248,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=past_time,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Check if marked as expired
|
||||
@@ -275,7 +265,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format="", # Blank format
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify it was created successfully
|
||||
@@ -294,7 +283,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
session2 = SAMLSession.objects.create(
|
||||
@@ -306,7 +294,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Query by provider
|
||||
@@ -329,7 +316,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Check serializer property
|
||||
@@ -348,7 +334,6 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify sessions exist
|
||||
|
||||
@@ -7,7 +7,6 @@ from guardian.shortcuts import get_anonymous_user
|
||||
from lxml import etree # nosec
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import RequestFactory, create_test_cert, create_test_flow
|
||||
from authentik.lib.xml import lxml_from_string
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||
@@ -31,11 +30,6 @@ class TestSchema(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app",
|
||||
provider=self.provider,
|
||||
)
|
||||
self.source = SAMLSource.objects.create(
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestSendSamlLogoutResponse(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
@@ -155,7 +155,6 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
@@ -180,7 +179,6 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
@@ -200,7 +198,6 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
|
||||
@@ -217,7 +214,7 @@ class TestSendPostLogoutRequest(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
)
|
||||
@@ -90,7 +90,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
# Verify logout request was stored in plan context
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_redirect_view_handles_logout_response_with_plan_context(self):
|
||||
@@ -228,7 +228,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
# Verify logout request was stored in plan context
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_post_view_handles_logout_response_with_plan_context(self):
|
||||
@@ -396,7 +396,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer_override="https://idp2.example.com",
|
||||
issuer="https://idp2.example.com",
|
||||
invalidation_flow=None, # No invalidation flow
|
||||
)
|
||||
|
||||
@@ -524,7 +524,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.cert,
|
||||
@@ -714,7 +714,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="", # No SLS URL
|
||||
issuer_override="https://idp.example.com",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
app_no_sls = Application.objects.create(
|
||||
|
||||
@@ -11,12 +11,6 @@ from authentik.providers.saml.views.sp_slo import (
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
# Base path for Issuer/Entity ID
|
||||
path(
|
||||
"<slug:application_slug>/",
|
||||
sso.SAMLSSOBindingRedirectView.as_view(),
|
||||
name="base",
|
||||
),
|
||||
# SSO Bindings
|
||||
path(
|
||||
"<slug:application_slug>/sso/binding/redirect/",
|
||||
|
||||
@@ -81,7 +81,6 @@ class SAMLFlowFinalView(ChallengeStageView):
|
||||
"session": auth_session,
|
||||
"name_id": processor.name_id,
|
||||
"name_id_format": processor.name_id_format,
|
||||
"issuer": processor.issuer,
|
||||
"expires": processor.session_not_on_or_after_datetime,
|
||||
"expiring": True,
|
||||
},
|
||||
|
||||
@@ -107,25 +107,12 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
# Store relay state for the logout response
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = relay_state
|
||||
|
||||
# Look up the session issuer to use in the logout response
|
||||
auth_session = AuthenticatedSession.from_request(request, request.user)
|
||||
session_issuer = None
|
||||
if auth_session:
|
||||
saml_session = SAMLSession.objects.filter(
|
||||
session=auth_session,
|
||||
user=request.user,
|
||||
provider=self.provider,
|
||||
).first()
|
||||
if saml_session:
|
||||
session_issuer = saml_session.issuer
|
||||
|
||||
if self.provider.logout_method == SAMLLogoutMethods.FRONTCHANNEL_NATIVE:
|
||||
# Native mode - user will be redirected/posted away from authentik
|
||||
processor = LogoutResponseProcessor(
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
if self.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -165,7 +152,6 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
sls_url=self.provider.sls_url,
|
||||
logout_request_id=logout_request.id if logout_request else None,
|
||||
relay_state=relay_state,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
LOGGER.debug(
|
||||
@@ -182,7 +168,6 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
logout_response = processor.build_response()
|
||||
|
||||
@@ -97,9 +97,6 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
|
||||
if cached_config is not None:
|
||||
return cached_config
|
||||
|
||||
if self.provider.compatibility_mode == SCIMCompatibilityMode.VCENTER:
|
||||
return default_config
|
||||
|
||||
# Attempt to fetch from remote
|
||||
path = "/ServiceProviderConfig"
|
||||
if self.provider.compatibility_mode == SCIMCompatibilityMode.SALESFORCE:
|
||||
|
||||
@@ -94,7 +94,6 @@ class Migration(migrations.Migration):
|
||||
("slack", "Slack"),
|
||||
("sfdc", "Salesforce"),
|
||||
("webex", "Webex"),
|
||||
("vcenter", "vCenter"),
|
||||
],
|
||||
default="default",
|
||||
help_text="Alter authentik behavior for vendor-specific SCIM implementations.",
|
||||
|
||||
@@ -83,7 +83,6 @@ class SCIMCompatibilityMode(models.TextChoices):
|
||||
SLACK = "slack", _("Slack")
|
||||
SALESFORCE = "sfdc", _("Salesforce")
|
||||
WEBEX = "webex", _("Webex")
|
||||
VCENTER = "vcenter", _("vCenter")
|
||||
|
||||
|
||||
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
|
||||
@@ -94,7 +94,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||
def get_user_id(self, info: dict[str, Any]) -> str | None:
|
||||
"""Return unique identifier from the profile info."""
|
||||
if "id" in info:
|
||||
return str(info["id"])
|
||||
return info["id"]
|
||||
return None
|
||||
|
||||
def handle_login_failure(self, reason: str) -> HttpResponse:
|
||||
|
||||
@@ -353,7 +353,7 @@ class IdentificationStageView(ChallengeStageView):
|
||||
PLAN_CONTEXT_APPLICATION, Application()
|
||||
)
|
||||
challenge.initial_data["application_pre"] = app.name
|
||||
if not app.meta_hide and (launch_url := app.get_launch_url()):
|
||||
if launch_url := app.get_launch_url():
|
||||
challenge.initial_data["application_pre_launch"] = launch_url
|
||||
if (
|
||||
PLAN_CONTEXT_DEVICE in self.executor.plan.context
|
||||
|
||||
@@ -5215,11 +5215,6 @@
|
||||
"type": "string",
|
||||
"title": "Group"
|
||||
},
|
||||
"meta_hide": {
|
||||
"type": "boolean",
|
||||
"title": "Meta hide",
|
||||
"description": "Hide this application from the user's My applications page."
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
@@ -5537,14 +5532,6 @@
|
||||
"minLength": 1,
|
||||
"title": "Password"
|
||||
},
|
||||
"password_hash": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"minLength": 1,
|
||||
"title": "Password hash"
|
||||
},
|
||||
"permissions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@@ -9975,23 +9962,6 @@
|
||||
"title": "Client Type",
|
||||
"description": "Confidential clients are capable of maintaining the confidentiality of their credentials. Public clients are incapable"
|
||||
},
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"implicit",
|
||||
"hybrid",
|
||||
"refresh_token",
|
||||
"client_credentials",
|
||||
"password",
|
||||
"urn:ietf:params:oauth:grant-type:device_code"
|
||||
],
|
||||
"title": "Grant types"
|
||||
},
|
||||
"title": "Grant types"
|
||||
},
|
||||
"client_id": {
|
||||
"type": "string",
|
||||
"maxLength": 255,
|
||||
@@ -10825,10 +10795,11 @@
|
||||
"title": "Audience",
|
||||
"description": "Value of the audience restriction field of the assertion. When left empty, no audience restriction will be added."
|
||||
},
|
||||
"issuer_override": {
|
||||
"issuer": {
|
||||
"type": "string",
|
||||
"title": "Issuer override",
|
||||
"description": "Also known as EntityID. Providing a value overrides the default issuer generated by authentik."
|
||||
"minLength": 1,
|
||||
"title": "Issuer",
|
||||
"description": "Also known as EntityID"
|
||||
},
|
||||
"assertion_valid_not_before": {
|
||||
"type": "string",
|
||||
@@ -11111,8 +11082,7 @@
|
||||
"aws",
|
||||
"slack",
|
||||
"sfdc",
|
||||
"webex",
|
||||
"vcenter"
|
||||
"webex"
|
||||
],
|
||||
"title": "SCIM Compatibility Mode",
|
||||
"description": "Alter authentik behavior for vendor-specific SCIM implementations."
|
||||
|
||||
@@ -11,7 +11,6 @@ context:
|
||||
group_name: authentik Admins
|
||||
email: !Env [AUTHENTIK_BOOTSTRAP_EMAIL, "root@example.com"]
|
||||
password: !Env [AUTHENTIK_BOOTSTRAP_PASSWORD, null]
|
||||
password_hash: !Env [AUTHENTIK_BOOTSTRAP_PASSWORD_HASH, null]
|
||||
token: !Env [AUTHENTIK_BOOTSTRAP_TOKEN, null]
|
||||
entries:
|
||||
- model: authentik_core.group
|
||||
@@ -32,7 +31,6 @@ entries:
|
||||
groups:
|
||||
- !KeyOf admin-group
|
||||
password: !Context password
|
||||
password_hash: !Context password_hash
|
||||
- model: authentik_core.token
|
||||
state: created
|
||||
conditions:
|
||||
|
||||
@@ -75,10 +75,6 @@ entries:
|
||||
url: https://localhost:8443/test/a/authentik/callback
|
||||
- matching_mode: strict
|
||||
url: https://host.docker.internal:8443/test/a/authentik/callback
|
||||
grant_types:
|
||||
- authorization_code
|
||||
- implicit
|
||||
- refresh_token
|
||||
property_mappings:
|
||||
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-openid]]
|
||||
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-email]]
|
||||
@@ -110,10 +106,6 @@ entries:
|
||||
url: https://localhost:8443/test/a/authentik/callback
|
||||
- matching_mode: strict
|
||||
url: https://host.docker.internal:8443/test/a/authentik/callback
|
||||
grant_types:
|
||||
- authorization_code
|
||||
- implicit
|
||||
- refresh_token
|
||||
property_mappings:
|
||||
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-openid]]
|
||||
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-email]]
|
||||
|
||||
6
build.rs
6
build.rs
@@ -1,6 +0,0 @@
|
||||
fn main() {
|
||||
#[cfg(feature = "core")]
|
||||
{
|
||||
pyo3_build_config::add_libpython_rpath_link_args();
|
||||
}
|
||||
}
|
||||
2
go.mod
2
go.mod
@@ -7,7 +7,7 @@ require (
|
||||
beryju.io/radius-eap v0.1.0
|
||||
github.com/avast/retry-go/v4 v4.7.0
|
||||
github.com/coreos/go-oidc/v3 v3.18.0
|
||||
github.com/getsentry/sentry-go v0.46.0
|
||||
github.com/getsentry/sentry-go v0.45.1
|
||||
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
||||
github.com/go-ldap/ldap/v3 v3.4.13
|
||||
github.com/go-openapi/runtime v0.29.4
|
||||
|
||||
4
go.sum
4
go.sum
@@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
||||
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/getsentry/sentry-go v0.46.0 h1:mbdDaarbUdOt9X+dx6kDdntkShLEX3/+KyOsVDTPDj0=
|
||||
github.com/getsentry/sentry-go v0.46.0/go.mod h1:evVbw2qotNUdYG8KxXbAdjOQWWvWIwKxpjdZZIvcIPw=
|
||||
github.com/getsentry/sentry-go v0.45.1 h1:9rfzJtGiJG+MGIaWZXidDGHcH5GU1Z5y0WVJGf9nysw=
|
||||
github.com/getsentry/sentry-go v0.45.1/go.mod h1:XDotiNZbgf5U8bPDUAfvcFmOnMQQceESxyKaObSssW0=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
|
||||
@@ -29,7 +29,7 @@ RUN npm run build && \
|
||||
npm run build:sfe
|
||||
|
||||
# Stage: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS go-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS go-builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@@ -147,11 +147,8 @@ RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
--mount=type=bind,target=Cargo.toml,src=Cargo.toml \
|
||||
--mount=type=bind,target=Cargo.lock,src=Cargo.lock \
|
||||
--mount=type=bind,target=.cargo/,src=.cargo/ \
|
||||
--mount=type=bind,target=build.rs,src=build.rs \
|
||||
--mount=type=bind,target=src/,src=src/ \
|
||||
--mount=type=bind,target=packages/ak-axum,src=packages/ak-axum \
|
||||
--mount=type=bind,target=packages/ak-common,src=packages/ak-common \
|
||||
--mount=type=bind,target=packages/client-rust,src=packages/client-rust \
|
||||
--mount=type=bind,target=packages/,src=packages/ \
|
||||
--mount=type=bind,target=authentik/lib/default.yml,src=authentik/lib/default.yml \
|
||||
# Required otherwise workspace discovery fails
|
||||
--mount=type=bind,target=website/scripts/docsmg/,src=website/scripts/docsmg/ \
|
||||
@@ -194,10 +191,7 @@ COPY --from=rust-toolchain /root/.cargo /root/.cargo
|
||||
ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
|
||||
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
|
||||
--mount=type=bind,target=uv.lock,src=uv.lock \
|
||||
--mount=type=bind,target=packages/ak-guardian,src=packages/ak-guardian \
|
||||
--mount=type=bind,target=packages/django-channels-postgres,src=packages/django-channels-postgres \
|
||||
--mount=type=bind,target=packages/django-dramatiq-postgres,src=packages/django-dramatiq-postgres \
|
||||
--mount=type=bind,target=packages/django-postgres-cache,src=packages/django-postgres-cache \
|
||||
--mount=type=bind,target=packages,src=packages \
|
||||
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
--mount=type=cache,id=uv-python-deps-$TARGETARCH$TARGETVARIANT,target=/root/.cache/uv \
|
||||
uv sync --frozen --no-install-project --no-dev
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
@@ -21,7 +21,7 @@ COPY web .
|
||||
RUN npm run build-proxy
|
||||
|
||||
# Stage 2: Build
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import faulthandler
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
@@ -77,12 +76,6 @@ def main(worker_id: int, socket_path: str):
|
||||
signal.signal(signal.SIGINT, immediate_shutdown)
|
||||
signal.signal(signal.SIGQUIT, immediate_shutdown)
|
||||
signal.signal(signal.SIGTERM, graceful_shutdown)
|
||||
# SIGUSR1 dumps every thread's traceback to stderr. Without this, the default
|
||||
# action is "terminate", which kills the worker (and trips the Rust supervisor).
|
||||
# Side-benefit: signal delivery wakes the eval loop, so `pdb -p` can attach to
|
||||
# an otherwise-idle worker parked in a C-level syscall.
|
||||
faulthandler.enable()
|
||||
faulthandler.register(signal.SIGUSR1)
|
||||
|
||||
random.seed()
|
||||
|
||||
@@ -104,11 +97,7 @@ def main(worker_id: int, socket_path: str):
|
||||
# Notify rust process that we are ready
|
||||
os.kill(os.getppid(), signal.SIGUSR2)
|
||||
|
||||
# Poll instead of waiting indefinitely so the main thread's eval loop ticks
|
||||
# periodically — PEP 768's debugger pending hook is serviced on the main
|
||||
# thread, and a permanent Event.wait() never returns to bytecode execution.
|
||||
while not shutdown.wait(timeout=1.0):
|
||||
pass
|
||||
shutdown.wait()
|
||||
|
||||
logger.info("Shutting down worker...")
|
||||
|
||||
|
||||
Binary file not shown.
@@ -8,7 +8,7 @@ msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PACKAGE VERSION\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2026-04-29 00:28+0000\n"
|
||||
"POT-Creation-Date: 2026-04-23 00:25+0000\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
@@ -392,10 +392,6 @@ msgstr ""
|
||||
msgid "Open launch URL in a new browser tab or window."
|
||||
msgstr ""
|
||||
|
||||
#: authentik/core/models.py
|
||||
msgid "Hide this application from the user's My applications page."
|
||||
msgstr ""
|
||||
|
||||
#: authentik/core/models.py
|
||||
msgid "Application"
|
||||
msgstr ""
|
||||
@@ -2495,9 +2491,7 @@ msgid ""
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/saml/models.py
|
||||
msgid ""
|
||||
"Also known as EntityID. Providing a value overrides the default issuer "
|
||||
"generated by authentik."
|
||||
msgid "Also known as EntityID"
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/saml/models.py
|
||||
@@ -2691,10 +2685,6 @@ msgstr ""
|
||||
msgid "SAML NameID format"
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/saml/models.py
|
||||
msgid "SAML Issuer used for this session"
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/saml/models.py
|
||||
msgid "SAML Session"
|
||||
msgstr ""
|
||||
@@ -2727,10 +2717,6 @@ msgstr ""
|
||||
msgid "Webex"
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/scim/models.py
|
||||
msgid "vCenter"
|
||||
msgstr ""
|
||||
|
||||
#: authentik/providers/scim/models.py
|
||||
msgid "Group filters used to define sync-scope for groups."
|
||||
msgstr ""
|
||||
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,6 @@ durstr.workspace = true
|
||||
eyre.workspace = true
|
||||
forwarded-header-value.workspace = true
|
||||
futures.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio.workspace = true
|
||||
tower-http.workspace = true
|
||||
|
||||
@@ -1,737 +0,0 @@
|
||||
//! axum-server acceptor that catches panics and shuts down the application.
|
||||
|
||||
use std::{
|
||||
any::Any,
|
||||
io::{self, IoSlice},
|
||||
panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use ak_common::Arbiter;
|
||||
use axum_server::accept::Accept;
|
||||
use futures::{FutureExt as _, future::BoxFuture};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tower::Service;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
fn extract_panic_msg<'a>(panic: &'a Box<dyn Any + Send + 'static>) -> &'a str {
|
||||
panic
|
||||
.downcast_ref::<&str>()
|
||||
.copied()
|
||||
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
|
||||
.unwrap_or("unknown panic message")
|
||||
}
|
||||
|
||||
/// Acceptor catching panics from the underlying acceptor.
|
||||
///
|
||||
/// Also wraps the stream and service to catch panics.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CatchPanicAcceptor<A> {
|
||||
inner: A,
|
||||
arbiter: Arbiter,
|
||||
}
|
||||
|
||||
impl<A> CatchPanicAcceptor<A> {
|
||||
pub(crate) fn new(inner: A, arbiter: Arbiter) -> Self {
|
||||
Self { inner, arbiter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, I, S> Accept<I, S> for CatchPanicAcceptor<A>
|
||||
where
|
||||
A: Accept<I, S> + Clone + Send + 'static,
|
||||
A::Stream: AsyncRead + AsyncWrite + Send,
|
||||
A::Service: Send,
|
||||
A::Future: Send,
|
||||
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
S: Send + 'static,
|
||||
{
|
||||
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
|
||||
type Service = CatchPanicService<A::Service>;
|
||||
type Stream = CatchPanicStream<A::Stream>;
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn accept(&self, stream: I, service: S) -> Self::Future {
|
||||
let acceptor = self.inner.clone();
|
||||
let arbiter = self.arbiter.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
match AssertUnwindSafe(acceptor.accept(stream, service))
|
||||
.catch_unwind()
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
let (stream, service) = result?;
|
||||
Ok((
|
||||
CatchPanicStream::new(stream, arbiter.clone()),
|
||||
CatchPanicService::new(service, arbiter),
|
||||
))
|
||||
}
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"acceptor panicked, shutting down immediately"
|
||||
);
|
||||
arbiter.do_fast_shutdown().await;
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// A stream wrapper that catches panics from the underlying stream.
|
||||
pub(crate) struct CatchPanicStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
arbiter: Arbiter,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> CatchPanicStream<S> {
|
||||
pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self {
|
||||
Self { inner, arbiter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead> AsyncRead for CatchPanicStream<S> {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| this.inner.poll_read(cx, buf))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream poll_read panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite> AsyncWrite for CatchPanicStream<S> {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| this.inner.poll_write(cx, buf))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream poll_write panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| {
|
||||
this.inner.poll_write_vectored(cx, bufs)
|
||||
})) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream poll_write_vectored panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match catch_unwind(AssertUnwindSafe(|| self.inner.is_write_vectored())) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream is_write_vectored panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = self.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| this.inner.poll_flush(cx))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream poll_flush panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| this.inner.poll_shutdown(cx))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"stream poll_shutdown panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A panic wrapper that catches panics from the underlying service.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CatchPanicService<S> {
|
||||
inner: S,
|
||||
arbiter: Arbiter,
|
||||
}
|
||||
|
||||
impl<S> CatchPanicService<S> {
|
||||
pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self {
|
||||
Self { inner, arbiter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R> Service<R> for CatchPanicService<S>
|
||||
where
|
||||
S: Service<R>,
|
||||
{
|
||||
type Error = S::Error;
|
||||
type Future = CatchPanicFuture<S::Future>;
|
||||
type Response = S::Response;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let inner = &mut self.inner;
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| inner.poll_ready(cx))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"service poll_ready panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = self.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn call(&mut self, req: R) -> Self::Future {
|
||||
let inner = &mut self.inner;
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| inner.call(req))) {
|
||||
Ok(future) => CatchPanicFuture::new(future, self.arbiter.clone()),
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"service call panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = self.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// A Future wrapper that catches panics from the inner future.
|
||||
pub(crate) struct CatchPanicFuture<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
arbiter: Arbiter,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> CatchPanicFuture<F> {
|
||||
fn new(inner: F, arbiter: Arbiter) -> Self {
|
||||
Self { inner, arbiter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> Future for CatchPanicFuture<F> {
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
match catch_unwind(AssertUnwindSafe(|| this.inner.poll(cx))) {
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
error!(
|
||||
panic = extract_panic_msg(&panic),
|
||||
"service future panicked, shutting down immediately"
|
||||
);
|
||||
let arbiter = this.arbiter.clone();
|
||||
tokio::spawn(async move { arbiter.do_fast_shutdown().await });
|
||||
resume_unwind(panic);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
io,
|
||||
panic::{AssertUnwindSafe, panic_any},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use ak_common::{Arbiter, Tasks};
|
||||
use axum_server::accept::Accept;
|
||||
use futures::{
|
||||
FutureExt as _,
|
||||
future::{BoxFuture, poll_fn},
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt as _, AsyncWriteExt as _, DuplexStream, ReadBuf, duplex},
|
||||
time::{Duration, timeout},
|
||||
};
|
||||
use tower::Service;
|
||||
|
||||
use super::{CatchPanicAcceptor, CatchPanicService, CatchPanicStream};
|
||||
|
||||
fn duplex_stream() -> DuplexStream {
|
||||
let (stream, _peer) = duplex(1024);
|
||||
stream
|
||||
}
|
||||
|
||||
/// Returns `true` if the arbiter's fast-shutdown has already been triggered.
|
||||
async fn fast_shutdown_triggered(arbiter: &Arbiter) -> bool {
|
||||
timeout(Duration::from_millis(50), arbiter.fast_shutdown())
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OkAcceptor;
|
||||
|
||||
impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for OkAcceptor {
|
||||
type Future = BoxFuture<'static, io::Result<(I, S)>>;
|
||||
type Service = S;
|
||||
type Stream = I;
|
||||
|
||||
fn accept(&self, stream: I, service: S) -> Self::Future {
|
||||
Box::pin(async move { Ok((stream, service)) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ErrorAcceptor;
|
||||
|
||||
impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for ErrorAcceptor {
|
||||
type Future = BoxFuture<'static, io::Result<(I, S)>>;
|
||||
type Service = S;
|
||||
type Stream = I;
|
||||
|
||||
fn accept(&self, _stream: I, _service: S) -> Self::Future {
|
||||
Box::pin(async move { Err(io::Error::other("inner error")) })
|
||||
}
|
||||
}
|
||||
|
||||
/// Panics with a `&'static str` payload.
|
||||
#[derive(Clone)]
|
||||
struct PanicStrAcceptor;
|
||||
|
||||
impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicStrAcceptor {
|
||||
type Future = BoxFuture<'static, io::Result<(I, S)>>;
|
||||
type Service = S;
|
||||
type Stream = I;
|
||||
|
||||
fn accept(&self, _stream: I, _service: S) -> Self::Future {
|
||||
Box::pin(async move { panic!("str panic message") })
|
||||
}
|
||||
}
|
||||
|
||||
/// Panics with a `String` payload.
|
||||
#[derive(Clone)]
|
||||
struct PanicStringAcceptor;
|
||||
|
||||
impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicStringAcceptor {
|
||||
type Future = BoxFuture<'static, io::Result<(I, S)>>;
|
||||
type Service = S;
|
||||
type Stream = I;
|
||||
|
||||
fn accept(&self, _stream: I, _service: S) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let msg = "string panic message".to_owned();
|
||||
panic_any(msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Panics with a payload that is neither `&str` nor `String`.
|
||||
#[derive(Clone)]
|
||||
struct PanicUnknownAcceptor;
|
||||
|
||||
impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicUnknownAcceptor {
|
||||
type Future = BoxFuture<'static, io::Result<(I, S)>>;
|
||||
type Service = S;
|
||||
type Stream = I;
|
||||
|
||||
fn accept(&self, _stream: I, _service: S) -> Self::Future {
|
||||
Box::pin(async move { panic_any(42u32) })
|
||||
}
|
||||
}
|
||||
|
||||
struct PanicStream;
|
||||
|
||||
impl tokio::io::AsyncRead for PanicStream {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
panic!("poll_read panic")
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for PanicStream {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
panic!("poll_write panic")
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
panic!("poll_flush panic")
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
panic!("poll_shutdown panic")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OkService;
|
||||
|
||||
impl Service<()> for OkService {
|
||||
type Error = Infallible;
|
||||
type Future = futures::future::Ready<Result<(), Infallible>>;
|
||||
type Response = ();
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: ()) -> Self::Future {
|
||||
futures::future::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct PanicPollReadyService;
|
||||
|
||||
impl Service<()> for PanicPollReadyService {
|
||||
type Error = Infallible;
|
||||
type Future = futures::future::Ready<Result<(), Infallible>>;
|
||||
type Response = ();
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
panic!("poll_ready panic")
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: ()) -> Self::Future {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
struct PanicCallBodyService;
|
||||
|
||||
impl Service<()> for PanicCallBodyService {
|
||||
type Error = Infallible;
|
||||
type Future = futures::future::Ready<Result<(), Infallible>>;
|
||||
type Response = ();
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: ()) -> Self::Future {
|
||||
panic!("call body panic")
|
||||
}
|
||||
}
|
||||
|
||||
struct PanicFuture;
|
||||
|
||||
impl Future for PanicFuture {
|
||||
type Output = Result<(), Infallible>;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
panic!("future panic")
|
||||
}
|
||||
}
|
||||
|
||||
struct PanicCallFutureService;
|
||||
|
||||
impl Service<()> for PanicCallFutureService {
|
||||
type Error = Infallible;
|
||||
type Future = PanicFuture;
|
||||
type Response = ();
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: ()) -> Self::Future {
|
||||
PanicFuture
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_passes_through_success() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let acceptor = CatchPanicAcceptor::new(OkAcceptor, arbiter.clone());
|
||||
|
||||
let result = acceptor.accept(duplex_stream(), OkService).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_passes_through_error() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let acceptor = CatchPanicAcceptor::new(ErrorAcceptor, arbiter.clone());
|
||||
|
||||
let result = acceptor.accept(duplex_stream(), OkService).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.err().unwrap().to_string(), "inner error");
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_catches_str_panic_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let acceptor = CatchPanicAcceptor::new(PanicStrAcceptor, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_catches_string_panic_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let acceptor = CatchPanicAcceptor::new(PanicStringAcceptor, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_catches_unknown_panic_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let acceptor = CatchPanicAcceptor::new(PanicUnknownAcceptor, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_read_passes_through() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let (mut a, mut b) = duplex(1024);
|
||||
b.write_all(b"hello").await.unwrap();
|
||||
|
||||
let mut stream = CatchPanicStream::new(&mut a, arbiter.clone());
|
||||
let mut buf = [0u8; 5];
|
||||
let result = stream.read(&mut buf).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(&buf, b"hello");
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_read_panic_returns_error_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(stream.read(&mut [0u8; 10]))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_write_passes_through() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let (mut a, _b) = duplex(1024);
|
||||
|
||||
let mut stream = CatchPanicStream::new(&mut a, arbiter.clone());
|
||||
let result = stream.write_all(b"hello").await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_write_panic_returns_error_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(stream.write(b"hello"))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_flush_panic_returns_error_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(stream.flush()).catch_unwind().await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_poll_shutdown_panic_returns_error_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(stream.shutdown()).catch_unwind().await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_poll_ready_passes_through() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut service = CatchPanicService::new(OkService, arbiter.clone());
|
||||
|
||||
let result = poll_fn(|cx| service.poll_ready(cx)).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_poll_ready_panic_re_panics_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut service = CatchPanicService::new(PanicPollReadyService, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(poll_fn(|cx| service.poll_ready(cx)))
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_call_passes_through() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut service = CatchPanicService::new(OkService, arbiter.clone());
|
||||
|
||||
let result = service.call(()).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(!fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_call_body_panic_re_panics_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut service = CatchPanicService::new(PanicCallBodyService, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(async { service.call(()).await })
|
||||
.catch_unwind()
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_call_future_panic_re_panics_and_shuts_down() {
|
||||
let tasks = Tasks::new().expect("failed to create tasks");
|
||||
let arbiter = tasks.arbiter();
|
||||
let mut service = CatchPanicService::new(PanicCallFutureService, arbiter.clone());
|
||||
|
||||
let result = AssertUnwindSafe(service.call(())).catch_unwind().await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(fast_shutdown_triggered(&arbiter).await);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,2 @@
|
||||
pub mod catch_panic;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls;
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
//! Utilities for working with [`Router`].
|
||||
|
||||
use ak_common::config;
|
||||
use axum::{
|
||||
Router,
|
||||
extract::Request,
|
||||
http::{HeaderName, HeaderValue, StatusCode},
|
||||
middleware::{Next, from_fn},
|
||||
response::Response,
|
||||
};
|
||||
use axum::{Router, http::StatusCode, middleware::from_fn};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
|
||||
@@ -19,16 +13,6 @@ use crate::{
|
||||
tracing::{span_middleware, tracing_middleware},
|
||||
};
|
||||
|
||||
const X_POWERED_BY: HeaderName = HeaderName::from_static("x-powered-by");
|
||||
|
||||
async fn powered_by_authentik_middleware(request: Request, next: Next) -> Response {
|
||||
let mut response = next.run(request).await;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(X_POWERED_BY, HeaderValue::from_static("authentik"));
|
||||
response
|
||||
}
|
||||
|
||||
/// Wrap a [`Router`] with common middlewares.
|
||||
///
|
||||
/// Set `with_tracing` to [`true`] to log requests.
|
||||
@@ -46,7 +30,6 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router {
|
||||
timeout,
|
||||
))
|
||||
.layer(from_fn(span_middleware))
|
||||
.layer(from_fn(powered_by_authentik_middleware))
|
||||
.layer(from_fn(trusted_proxy_middleware))
|
||||
.layer(from_fn(client_ip_middleware))
|
||||
.layer(from_fn(scheme_middleware))
|
||||
|
||||
@@ -12,9 +12,7 @@ use axum_server::{
|
||||
use eyre::Result;
|
||||
use tracing::{info, trace};
|
||||
|
||||
use crate::accept::{
|
||||
catch_panic::CatchPanicAcceptor, proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor,
|
||||
};
|
||||
use crate::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor};
|
||||
|
||||
async fn run_plain(
|
||||
arbiter: Arbiter,
|
||||
@@ -29,10 +27,7 @@ async fn run_plain(
|
||||
arbiter.add_net_handle(handle.clone()).await;
|
||||
|
||||
let res = axum_server::Server::bind(addr)
|
||||
.acceptor(CatchPanicAcceptor::new(
|
||||
ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()),
|
||||
arbiter.clone(),
|
||||
))
|
||||
.acceptor(ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()))
|
||||
.handle(handle)
|
||||
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
|
||||
.await;
|
||||
@@ -85,10 +80,7 @@ pub(crate) async fn run_unix(
|
||||
}
|
||||
}
|
||||
let res = axum_server::Server::bind(addr.clone())
|
||||
.acceptor(CatchPanicAcceptor::new(
|
||||
DefaultAcceptor::new(),
|
||||
arbiter.clone(),
|
||||
))
|
||||
.acceptor(DefaultAcceptor::new())
|
||||
.handle(handle)
|
||||
.serve(router.into_make_service())
|
||||
.await;
|
||||
@@ -141,12 +133,9 @@ async fn run_tls(
|
||||
arbiter.add_net_handle(handle.clone()).await;
|
||||
|
||||
axum_server::Server::bind(addr)
|
||||
.acceptor(CatchPanicAcceptor::new(
|
||||
ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new(
|
||||
RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()),
|
||||
)),
|
||||
arbiter.clone(),
|
||||
))
|
||||
.acceptor(ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new(
|
||||
RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()),
|
||||
)))
|
||||
.handle(handle)
|
||||
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
|
||||
.await?;
|
||||
|
||||
@@ -235,7 +235,7 @@ impl Arbiter {
|
||||
}
|
||||
|
||||
/// Shutdown the application immediately.
|
||||
pub async fn do_fast_shutdown(&self) {
|
||||
async fn do_fast_shutdown(&self) {
|
||||
info!("arbiter has been told to shutdown immediately");
|
||||
self.unix_handles
|
||||
.lock()
|
||||
@@ -253,7 +253,7 @@ impl Arbiter {
|
||||
}
|
||||
|
||||
/// Shutdown the application gracefully.
|
||||
pub async fn do_graceful_shutdown(&self) {
|
||||
async fn do_graceful_shutdown(&self) {
|
||||
info!("arbiter has been told to shutdown gracefully");
|
||||
// Match the value in lifecycle/gunicorn.conf.py for graceful shutdown
|
||||
let timeout = Some(Duration::from_secs(30 + 5));
|
||||
|
||||
@@ -16,10 +16,7 @@ use url::Url;
|
||||
pub mod schema;
|
||||
pub use schema::Config;
|
||||
|
||||
use crate::{
|
||||
arbiter::{Arbiter, Event, Tasks},
|
||||
config::schema::KEYS_TO_PARSE_AS_LIST,
|
||||
};
|
||||
use crate::arbiter::{Arbiter, Event, Tasks};
|
||||
|
||||
static DEFAULT_CONFIG: &str = include_str!("../../../../authentik/lib/default.yml");
|
||||
static CONFIG_MANAGER: OnceLock<ConfigManager> = OnceLock::new();
|
||||
@@ -78,15 +75,11 @@ impl Config {
|
||||
config_rs::File::from(path.as_path()).format(config_rs::FileFormat::Yaml),
|
||||
);
|
||||
}
|
||||
let mut env_source = config_rs::Environment::with_prefix("AUTHENTIK")
|
||||
.prefix_separator("_")
|
||||
.separator("__")
|
||||
.try_parsing(true)
|
||||
.list_separator(",");
|
||||
for key in KEYS_TO_PARSE_AS_LIST {
|
||||
env_source = env_source.with_list_parse_key(key);
|
||||
}
|
||||
builder = builder.add_source(env_source);
|
||||
builder = builder.add_source(
|
||||
config_rs::Environment::with_prefix("AUTHENTIK")
|
||||
.prefix_separator("_")
|
||||
.separator("__"),
|
||||
);
|
||||
if let Some(overrides) = overrides {
|
||||
builder = builder.add_source(config_rs::File::from_str(
|
||||
&overrides.to_string(),
|
||||
@@ -462,92 +455,4 @@ mod tests {
|
||||
super::set(json!({"secret_key": "my_new_secret_key"})).expect("failed to set config");
|
||||
assert_eq!(super::get().secret_key, "my_new_secret_key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_bool_true() {
|
||||
#[expect(unsafe_code, reason = "testing")]
|
||||
// SAFETY: testing
|
||||
unsafe {
|
||||
env::set_var("AUTHENTIK_DEBUG", "true");
|
||||
}
|
||||
|
||||
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
|
||||
assert!(config.debug);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_bool_false() {
|
||||
#[expect(unsafe_code, reason = "testing")]
|
||||
// SAFETY: testing
|
||||
unsafe {
|
||||
env::set_var("AUTHENTIK_DEBUG", "false");
|
||||
}
|
||||
|
||||
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
|
||||
assert!(!config.debug);
|
||||
}
|
||||
|
||||
// See https://github.com/rust-cli/config-rs/issues/443
|
||||
// #[test]
|
||||
// fn env_list_empty() {
|
||||
// #[expect(unsafe_code, reason = "testing")]
|
||||
// // SAFETY: testing
|
||||
// unsafe {
|
||||
// env::set_var("AUTHENTIK_LISTEN__HTTP", "");
|
||||
// }
|
||||
//
|
||||
// let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
//
|
||||
// assert_eq!(config.listen.http, []);
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn env_list_one_element() {
|
||||
#[expect(unsafe_code, reason = "testing")]
|
||||
// SAFETY: testing
|
||||
unsafe {
|
||||
env::set_var("AUTHENTIK_LISTEN__HTTP", "[::1]:9000");
|
||||
}
|
||||
|
||||
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
|
||||
assert_eq!(
|
||||
config.listen.http,
|
||||
["[::1]:9000".parse().expect("infallible")]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_list_many_elements() {
|
||||
#[expect(unsafe_code, reason = "testing")]
|
||||
// SAFETY: testing
|
||||
unsafe {
|
||||
env::set_var("AUTHENTIK_LISTEN__HTTP", "[::1]:9000,[::1]:9001");
|
||||
}
|
||||
|
||||
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
|
||||
assert_eq!(
|
||||
config.listen.http,
|
||||
[
|
||||
"[::1]:9000".parse().expect("infallible"),
|
||||
"[::1]:9001".parse().expect("infallible")
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_string() {
|
||||
#[expect(unsafe_code, reason = "testing")]
|
||||
// SAFETY: testing
|
||||
unsafe {
|
||||
env::set_var("AUTHENTIK_SECRET_KEY", "my_secret_key");
|
||||
}
|
||||
|
||||
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
|
||||
|
||||
assert_eq!(config.secret_key, "my_secret_key",);
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user