Compare commits

..

2 Commits

Author SHA1 Message Date
Teffen Ellis
5050a94f16 web: fix stray > in "Not you?" link and add Playwright regression for #21571
Move the closing > of the opening <a> tag so the rendered link text no longer
carries a leading > glyph. Add a browser test that seeds the identification
stage with enable_remember_me, walks the identify -> password -> "Not you?"
path, and asserts the link text, the cleared username field, and the cleared
remember-me localStorage key.
Co-Authored-By: Agent <agent@authentik-i21647-current-instant-chili.girlbossru.sh>
2026-04-27 14:10:01 +00:00
Teffen Ellis
f5dd1b62ef web: Clear remember me before navigation. 2026-04-27 11:46:38 +02:00
183 changed files with 1550 additions and 22254 deletions

View File

@@ -64,7 +64,7 @@ runs:
rustflags: ""
- name: Setup rust dependencies
if: ${{ contains(inputs.dependencies, 'rust') }}
uses: taiki-e/install-action@cf525cb33f51aca27cd6fa02034117ab963ff9f1 # v2
uses: taiki-e/install-action@5f57d6cb7cd20b14a8a27f522884c4bc8a187458 # v2
with:
tool: cargo-deny cargo-machete cargo-llvm-cov nextest
- name: Setup node (web)

View File

@@ -90,7 +90,7 @@ jobs:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- uses: int128/docker-manifest-create-action@7df7f9e221d927eaadf87db231ddf728047308a4 # v2
- uses: int128/docker-manifest-create-action@3de37de96c4e900bc3eef9055d3c50abdb4f769d # v2
id: build
with:
tags: ${{ matrix.tag }}

View File

@@ -14,7 +14,6 @@ pyproject.toml @goauthentik/backend
uv.lock @goauthentik/backend
Cargo.toml @goauthentik/backend
Cargo.lock @goauthentik/backend
build.rs @goauthentik/backend
go.mod @goauthentik/backend
go.sum @goauthentik/backend
.cargo/ @goauthentik/backend

10
Cargo.lock generated
View File

@@ -198,7 +198,6 @@ dependencies = [
"metrics-exporter-prometheus",
"nix 0.31.2",
"pyo3",
"pyo3-build-config",
"sqlx",
"tokio",
"tracing",
@@ -217,7 +216,6 @@ dependencies = [
"eyre",
"forwarded-header-value",
"futures",
"pin-project-lite",
"tokio",
"tokio-rustls",
"tower",
@@ -1507,9 +1505,9 @@ dependencies = [
[[package]]
name = "hyper-unix-socket"
version = "0.6.1"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88978f1d73da0eb87d86555fcc40cbdd87bc86eb6525710b89db8c9278ec6a59"
checksum = "c255628da188a9d9ee218bae99da33a4b684ed63abe140a94d0f6e4b5af9a090"
dependencies = [
"bytes",
"hyper",
@@ -3000,9 +2998,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.40"
version = "0.23.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
dependencies = [
"aws-lc-rs",
"log",

View File

@@ -39,7 +39,7 @@ eyre = "= 0.6.12"
forwarded-header-value = "= 0.1.1"
futures = "= 0.3.32"
glob = "= 0.3.3"
hyper-unix-socket = "= 0.6.1"
hyper-unix-socket = "= 0.3.0"
hyper-util = "= 0.1.20"
ipnet = { version = "= 2.12.0", features = ["serde"] }
json-subscriber = "= 0.2.8"
@@ -49,7 +49,6 @@ nix = { version = "= 0.31.2", features = ["hostname", "signal"] }
notify = "= 8.2.0"
pin-project-lite = "= 0.2.17"
pyo3 = "= 0.28.3"
pyo3-build-config = "= 0.28.3"
regex = "= 1.12.3"
reqwest = { version = "= 0.13.2", features = [
"form",
@@ -66,7 +65,7 @@ reqwest-middleware = { version = "= 0.5.1", features = [
"query",
"rustls",
] }
rustls = { version = "= 0.23.40", features = ["fips"] }
rustls = { version = "= 0.23.39", features = ["fips"] }
sentry = { version = "= 0.47.0", default-features = false, features = [
"backtrace",
"contexts",
@@ -261,9 +260,6 @@ default = ["core", "proxy"]
core = ["ak-common/core", "dep:pyo3", "dep:sqlx"]
proxy = ["ak-common/proxy"]
[build-dependencies]
pyo3-build-config.workspace = true
[dependencies]
ak-axum.workspace = true
ak-common.workspace = true

View File

@@ -118,9 +118,6 @@ run-worker: ## Run the main authentik worker process
run-worker-watch: ## Run the authentik worker, with auto reloading
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- $(UV) run ak worker
debug-attach: ## Attach pdb to a running authentik Python worker (PEP 768). PID=<pid> to pick; SUDO=1 on macOS.
$(UV) run python scripts/debug_attach.py
core-i18n-extract:
$(UV) run ak makemessages \
--add-location file \

View File

@@ -5,7 +5,6 @@ from django.utils.translation import gettext_lazy as _
GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code"
GRANT_TYPE_IMPLICIT = "implicit"
GRANT_TYPE_HYBRID = "hybrid"
GRANT_TYPE_REFRESH_TOKEN = "refresh_token" # nosec
GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
GRANT_TYPE_PASSWORD = "password" # nosec

View File

@@ -30,8 +30,6 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
SAML_STATUS_SUCCESS = "urn:oasis:names:tc:SAML:2.0:status:Success"
DEFAULT_ISSUER = "authentik"
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2

View File

@@ -4,7 +4,7 @@ from collections.abc import Iterator
from copy import copy
from django.core.cache import cache
from django.db.models import Case, QuerySet
from django.db.models import Case, Q, QuerySet
from django.db.models.expressions import When
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
@@ -120,7 +120,6 @@ class ApplicationSerializer(ModelSerializer):
"meta_publisher",
"policy_engine_mode",
"group",
"meta_hide",
]
extra_kwargs = {
"backchannel_providers": {"required": False},
@@ -284,12 +283,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
) == "true"
queryset = self._filter_queryset_for_list(self.get_queryset())
queryset = queryset.exclude(meta_hide=True)
if only_with_launch_url:
# Pre-filter at DB level to skip expensive per-app policy evaluation
# for apps that can never appear in the launcher (no meta_launch_url
# and no provider, so no possible launch URL).
queryset = queryset.exclude(meta_launch_url="", provider__isnull=True)
# for apps that can never appear in the launcher:
# - No meta_launch_url AND no provider: no possible launch URL
# - meta_launch_url="blank://blank": documented convention to hide from launcher
queryset = queryset.exclude(
Q(meta_launch_url="", provider__isnull=True) | Q(meta_launch_url="blank://blank")
)
paginator: Pagination = self.paginator
paginated_apps = paginator.paginate_queryset(queryset, request)

View File

@@ -19,7 +19,7 @@ from rest_framework.authentication import SessionAuthentication
from rest_framework.decorators import action
from rest_framework.fields import CharField, IntegerField, SerializerMethodField
from rest_framework.permissions import IsAuthenticated
from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ValidationError
@@ -37,77 +37,6 @@ from authentik.endpoints.connectors.agent.auth import AgentAuth
from authentik.rbac.api.roles import RoleSerializer
from authentik.rbac.decorators import permission_required
class BulkManyRelatedField(ManyRelatedField):
"""ManyRelatedField that validates all PKs in a single query instead of one per PK."""
def to_internal_value(self, data):
if isinstance(data, str) or not hasattr(data, "__iter__"):
self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail("empty")
child = self.child_relation
pk_field = child.pk_field
# Coerce PKs through pk_field if defined
pk_map = {}
for item in data:
if isinstance(item, bool):
self.fail("incorrect_type", data_type=type(item).__name__)
pk = pk_field.to_internal_value(item) if pk_field else item
pk_map[pk] = item # map coerced PK -> original value for error reporting
queryset = child.get_queryset()
# Use count to validate all PKs exist in a single query
found_count = queryset.filter(pk__in=pk_map.keys()).count()
if found_count < len(pk_map):
# Some PKs not found — fall back to per-PK checks for error reporting.
# This only runs when there's an actual validation error (rare path).
for pk, original in pk_map.items():
if not queryset.filter(pk=pk).exists():
child.fail("does_not_exist", pk_value=original)
# Return raw PKs — Django's M2M set() accepts both objects and PKs,
# using get_prep_value() for type coercion. This avoids loading all
# objects into memory and avoids triggering post_init signals.
return list(pk_map.keys())
def to_representation(self, iterable):
# For non-prefetched querysets, get PKs directly without loading model instances.
# When prefetched, _result_cache is a list (possibly empty); when not, it's None.
if hasattr(iterable, "values_list") and getattr(iterable, "_result_cache", None) is None:
return list(iterable.values_list("pk", flat=True))
return super().to_representation(iterable)
class BulkPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""PrimaryKeyRelatedField that uses bulk validation when many=True."""
@classmethod
def many_init(cls, *args, **kwargs):
allow_empty = kwargs.pop("allow_empty", None)
max_length = kwargs.pop("max_length", None)
min_length = kwargs.pop("min_length", None)
child_relation = cls(*args, **kwargs)
list_kwargs = {
"child_relation": child_relation,
}
if allow_empty is not None:
list_kwargs["allow_empty"] = allow_empty
if max_length is not None:
list_kwargs["max_length"] = max_length
if min_length is not None:
list_kwargs["min_length"] = min_length
list_kwargs.update(
{
key: value
for key, value in kwargs.items()
if key in ("required", "default", "source")
}
)
return BulkManyRelatedField(**list_kwargs)
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
"pk",
"username",
@@ -150,7 +79,6 @@ class GroupSerializer(ModelSerializer):
"""Group Serializer"""
attributes = JSONDictField(required=False)
users = BulkPrimaryKeyRelatedField(queryset=User.objects.all(), many=True, default=list)
parents = PrimaryKeyRelatedField(queryset=Group.objects.all(), many=True, required=False)
parents_obj = SerializerMethodField(allow_null=True)
children_obj = SerializerMethodField(allow_null=True)
@@ -265,6 +193,9 @@ class GroupSerializer(ModelSerializer):
"children_obj",
]
extra_kwargs = {
"users": {
"default": list,
},
"children": {
"required": False,
"default": list,
@@ -294,7 +225,6 @@ class GroupFilter(FilterSet):
members_by_pk = ModelMultipleChoiceFilter(
field_name="users",
queryset=User.objects.all(),
distinct=False,
)
def filter_attributes(self, queryset, name, value):
@@ -346,8 +276,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
]
def get_queryset(self):
# Always prefetch parents and children since their PKs are always serialized
base_qs = Group.objects.all().prefetch_related("roles", "parents", "children")
base_qs = Group.objects.all().prefetch_related("roles")
if self.serializer_class(context={"request": self.request})._should_include_users:
# Only fetch fields needed by PartialUserSerializer to reduce DB load and instantiation
@@ -358,9 +287,16 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
)
)
# When include_users=false, skip users prefetch entirely.
# BulkManyRelatedField.to_representation will use values_list to get PKs
# directly without loading User instances into memory.
else:
base_qs = base_qs.prefetch_related(
Prefetch("users", queryset=User.objects.all().only("id"))
)
if self.serializer_class(context={"request": self.request})._should_include_children:
base_qs = base_qs.prefetch_related("children")
if self.serializer_class(context={"request": self.request})._should_include_parents:
base_qs = base_qs.prefetch_related("parents")
return base_qs

View File

@@ -6,7 +6,6 @@ from typing import Any
from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import AnonymousUser, Permission
from django.db.models import Exists, OuterRef, Prefetch, Q
from django.db.transaction import atomic
from django.db.utils import IntegrityError
from django.urls import reverse_lazy
@@ -14,7 +13,6 @@ from django.utils.http import urlencode
from django.utils.text import slugify
from django.utils.timezone import now
from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy
from django_filters.filters import (
BooleanFilter,
CharFilter,
@@ -107,10 +105,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger()
INVALID_PASSWORD_HASH_MESSAGE = gettext_lazy(
"Invalid password hash format. Must be a valid Django password hash."
)
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
@@ -137,7 +131,7 @@ class PartialGroupSerializer(ModelSerializer):
class UserSerializer(ModelSerializer):
"""User Serializer"""
is_superuser = SerializerMethodField()
is_superuser = BooleanField(read_only=True)
avatar = SerializerMethodField()
attributes = JSONDictField(required=False)
groups = PrimaryKeyRelatedField(
@@ -174,14 +168,6 @@ class UserSerializer(ModelSerializer):
return True
return str(request.query_params.get("include_roles", "true")).lower() == "true"
@extend_schema_field(BooleanField)
def get_is_superuser(self, instance: User) -> bool:
"""Use annotation if available to avoid N+1 query"""
ann = getattr(instance, "_annotated_is_superuser", None)
if ann is not None:
return ann
return instance.is_superuser
@extend_schema_field(PartialGroupSerializer(many=True))
def get_groups_obj(self, instance: User) -> list[PartialGroupSerializer] | None:
if not self._should_include_groups:
@@ -195,79 +181,47 @@ class UserSerializer(ModelSerializer):
return RoleSerializer(instance.roles, many=True).data
def __init__(self, *args, **kwargs):
"""Setting password and permissions directly is allowed only in blueprints."""
super().__init__(*args, **kwargs)
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["password"] = CharField(required=False, allow_null=True)
self.fields["password_hash"] = CharField(required=False, allow_null=True)
self.fields["permissions"] = ListField(
required=False,
child=ChoiceField(choices=get_permission_choices()),
)
def create(self, validated_data: dict) -> User:
"""Create a user, with blueprint-only password and permission writes."""
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
if is_blueprint:
password = validated_data.pop("password", None)
password_hash = validated_data.pop("password_hash", None)
permissions = validated_data.pop("permissions", [])
self._validate_password_inputs(password, password_hash)
"""If this serializer is used in the blueprint context, we allow for
directly setting a password. However should be done via the `set_password`
method instead of directly setting it like rest_framework."""
password = validated_data.pop("password", None)
perms_qs = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
instance: User = super().create(validated_data)
if is_blueprint:
self._set_password(instance, password, password_hash)
perms_qs = Permission.objects.filter(
codename__in=[permission.split(".")[1] for permission in permissions]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
instance.assign_perms_to_managed_role(perms_list)
self._ensure_password_not_empty(instance)
self._set_password(instance, password)
instance.assign_perms_to_managed_role(perms_list)
return instance
def update(self, instance: User, validated_data: dict) -> User:
"""Update a user, with blueprint-only password and permission writes."""
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
if is_blueprint:
password = validated_data.pop("password", None)
password_hash = validated_data.pop("password_hash", None)
permissions = validated_data.pop("permissions", [])
self._validate_password_inputs(password, password_hash)
"""Same as `create` above, set the password directly if we're in a blueprint
context"""
password = validated_data.pop("password", None)
perms_qs = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
instance = super().update(instance, validated_data)
if is_blueprint:
self._set_password(instance, password, password_hash)
perms_qs = Permission.objects.filter(
codename__in=[permission.split(".")[1] for permission in permissions]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
instance.assign_perms_to_managed_role(perms_list)
self._ensure_password_not_empty(instance)
self._set_password(instance, password)
instance.assign_perms_to_managed_role(perms_list)
return instance
def _validate_password_inputs(self, password: str | None, password_hash: str | None):
"""Validate mutually-exclusive password inputs before any model mutation."""
if password is not None and password_hash is not None:
raise ValidationError(_("Cannot set both password and password_hash. Use only one."))
if password_hash is None:
return
try:
User.validate_password_hash(password_hash)
except ValueError as exc:
LOGGER.warning("Failed to identify password hash format", exc_info=exc)
raise ValidationError(INVALID_PASSWORD_HASH_MESSAGE) from exc
def _set_password(self, instance: User, password: str | None, password_hash: str | None = None):
"""Set password from plain text or hash."""
if password_hash is not None:
instance.set_password_from_hash(password_hash)
instance.save()
elif password:
def _set_password(self, instance: User, password: str | None):
"""Set password of user if we're in a blueprint context, and if it's an empty
string then use an unusable password"""
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
instance.set_password(password)
instance.save()
def _ensure_password_not_empty(self, instance: User):
"""Store an explicit unusable password instead of an empty password field."""
if len(instance.password) == 0:
instance.set_unusable_password()
instance.save()
@@ -436,12 +390,6 @@ class UserPasswordSetSerializer(PassiveSerializer):
password = CharField(required=True)
class UserPasswordHashSetSerializer(PassiveSerializer):
"""Payload to set a users' password hash directly"""
password = CharField(required=True)
class UserServiceAccountSerializer(PassiveSerializer):
"""Payload to create a service account"""
@@ -593,30 +541,10 @@ class UserViewSet(
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()
# Always prefetch groups since group PKs are always serialized.
# Use full prefetch when include_groups=true (for groups_obj), ID-only otherwise.
if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("groups")
else:
base_qs = base_qs.prefetch_related(
Prefetch("groups", queryset=Group.objects.all().only("group_uuid"))
)
if self.serializer_class(context={"request": self.request})._should_include_roles:
base_qs = base_qs.prefetch_related("roles")
else:
base_qs = base_qs.prefetch_related(
Prefetch("roles", queryset=Role.objects.all().only("uuid"))
)
# Annotate is_superuser to avoid N+1 query per user
base_qs = base_qs.annotate(
_annotated_is_superuser=Exists(
Group.objects.filter(
is_superuser=True,
).filter(
Q(users=OuterRef("pk")) | Q(descendant_nodes__descendant__users=OuterRef("pk"))
)
)
)
return base_qs
@extend_schema(
@@ -785,11 +713,6 @@ class UserViewSet(
self.request.session.modified = True
return Response(serializer.initial_data)
def _update_session_hash_after_password_change(self, request: Request, user: User):
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
LOGGER.debug("Updating session hash after password change")
update_session_auth_hash(self.request, user)
@permission_required("authentik_core.reset_user_password")
@extend_schema(
request=UserPasswordSetSerializer,
@@ -813,45 +736,9 @@ class UserViewSet(
except (ValidationError, IntegrityError) as exc:
LOGGER.debug("Failed to set password", exc=exc)
return Response(status=400)
self._update_session_hash_after_password_change(request, user)
return Response(status=204)
@permission_required("authentik_core.reset_user_password")
@extend_schema(
request=UserPasswordHashSetSerializer,
responses={
204: OpenApiResponse(description="Successfully changed password"),
400: OpenApiResponse(description="Bad request"),
},
)
@action(
detail=True,
methods=["POST"],
permission_classes=[IsAuthenticated],
)
@validate(UserPasswordHashSetSerializer)
def set_password_hash(
self, request: Request, pk: int, body: UserPasswordHashSetSerializer
) -> Response:
"""Set a user's password from a pre-hashed Django password value.
Submit the Django password hash in the shared ``password`` request field.
This updates authentik's local password verifier only. It does not attempt
to propagate the password change to LDAP or Kerberos because no raw password
is available from the request payload.
"""
user: User = self.get_object()
try:
user.set_password_from_hash(body.validated_data["password"], request=request)
user.save()
except ValueError as exc:
LOGGER.debug("Failed to set password hash", exc=exc)
return Response(data={"password": [INVALID_PASSWORD_HASH_MESSAGE]}, status=400)
except (ValidationError, IntegrityError) as exc:
LOGGER.debug("Failed to set password hash", exc=exc)
return Response(status=400)
self._update_session_hash_after_password_change(request, user)
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
LOGGER.debug("Updating session hash after password change")
update_session_auth_hash(self.request, user)
return Response(status=204)
@permission_required("authentik_core.reset_user_password")

View File

@@ -1,28 +0,0 @@
"""Hash password using Django's password hashers"""
from django.contrib.auth.hashers import make_password
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):
"""Hash a password using Django's password hashers"""
help = "Hash a password for use with AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"
def add_arguments(self, parser):
parser.add_argument(
"password",
type=str,
help="Password to hash",
)
def handle(self, *args, **options):
password = options["password"]
if not password:
raise CommandError("Password cannot be empty")
try:
hashed = make_password(password)
self.stdout.write(hashed)
except ValueError as exc:
raise CommandError(f"Error hashing password: {exc}") from exc

View File

@@ -1,33 +0,0 @@
# Generated by Django 5.2.12 on 2026-04-09 18:04
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_blank_launch_url(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Application = apps.get_model("authentik_core", "Application")
Application.objects.using(db_alias).filter(meta_launch_url="blank://blank").update(
meta_hide=True, meta_launch_url=""
)
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0058_setup"),
]
operations = [
migrations.AddField(
model_name="application",
name="meta_hide",
field=models.BooleanField(
default=False,
help_text="Hide this application from the user's My applications page.",
),
),
migrations.RunPython(migrate_blank_launch_url, migrations.RunPython.noop),
]

View File

@@ -10,7 +10,7 @@ from uuid import uuid4
import pgtrigger
from deepmerge import always_merger
from django.contrib.auth.hashers import check_password, identify_hasher
from django.contrib.auth.hashers import check_password
from django.contrib.auth.models import AbstractUser, Permission
from django.contrib.auth.models import UserManager as DjangoUserManager
from django.contrib.sessions.base_session import AbstractBaseSession
@@ -560,33 +560,6 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
self.password_change_date = now()
return super().set_password(raw_password)
@staticmethod
def validate_password_hash(password_hash: str):
"""Validate that the value is a recognized Django password hash."""
identify_hasher(password_hash) # Raises ValueError if invalid
def set_password_from_hash(self, password_hash: str, signal=True, sender=None, request=None):
"""Set password directly from a pre-hashed value.
Unlike set_password(), this does not hash the input again. The provided value
must already be a valid Django password hash, and it is stored directly on the
user after validation.
Because no raw password is available, downstream password sync integrations
such as LDAP and Kerberos cannot be updated from this code path.
Raises ValueError if the hash format is not recognized.
"""
self.validate_password_hash(password_hash)
if self.pk and signal:
from authentik.core.signals import password_hash_changed
if not sender:
sender = self
password_hash_changed.send(sender=sender, user=self, request=request)
self.password = password_hash
self.password_change_date = now()
def check_password(self, raw_password: str) -> bool:
"""
Return a boolean of whether the raw_password was correct. Handles
@@ -762,9 +735,6 @@ class Application(SerializerModel, PolicyBindingModel):
meta_icon = FileField(default="", blank=True)
meta_description = models.TextField(default="", blank=True)
meta_publisher = models.TextField(default="", blank=True)
meta_hide = models.BooleanField(
default=False, help_text=_("Hide this application from the user's My applications page.")
)
objects = ApplicationQuerySet.as_manager()

View File

@@ -16,11 +16,7 @@ LOGGER = get_logger()
@receiver(post_startup)
def post_startup_setup_bootstrap(sender, **_):
if (
not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD")
and not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH")
and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN")
):
if not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD") and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN"):
return
LOGGER.info("Configuring authentik through bootstrap environment variables")
content = BlueprintInstance(path=BOOTSTRAP_BLUEPRINT).retrieve()

View File

@@ -24,8 +24,6 @@ from authentik.root.ws.consumer import build_device_group
# Arguments: user: User, password: str
password_changed = Signal()
# Arguments: user: User, request: HttpRequest | None
password_hash_changed = Signal()
# Arguments: credentials: dict[str, any], request: HttpRequest,
# stage: Stage, context: dict[str, any]
login_failed = Signal()

View File

@@ -129,7 +129,6 @@ class TestApplicationsAPI(APITestCase):
"meta_icon_url": None,
"meta_icon_themed_urls": None,
"meta_description": "",
"meta_hide": False,
"meta_publisher": "",
"policy_engine_mode": "any",
},
@@ -188,14 +187,12 @@ class TestApplicationsAPI(APITestCase):
"meta_icon_url": None,
"meta_icon_themed_urls": None,
"meta_description": "",
"meta_hide": False,
"meta_publisher": "",
"policy_engine_mode": "any",
},
{
"launch_url": None,
"meta_description": "",
"meta_hide": False,
"meta_icon": "",
"meta_icon_url": None,
"meta_icon_themed_urls": None,

View File

@@ -1,28 +0,0 @@
"""Tests for hash_password management command."""
from io import StringIO
from django.contrib.auth.hashers import check_password
from django.core.management import call_command
from django.core.management.base import CommandError
from django.test import TestCase
class TestHashPasswordCommand(TestCase):
"""Test hash_password management command."""
def test_hash_password(self):
"""Test hashing a password."""
out = StringIO()
call_command("hash_password", "test123", stdout=out)
hashed = out.getvalue().strip()
self.assertTrue(hashed.startswith("pbkdf2_sha256$"))
self.assertTrue(check_password("test123", hashed))
def test_hash_password_empty_fails(self):
"""Test that empty password raises error."""
with self.assertRaises(CommandError) as ctx:
call_command("hash_password", "")
self.assertIn("Password cannot be empty", str(ctx.exception))

View File

@@ -1,7 +1,6 @@
from http import HTTPStatus
from os import environ
from django.contrib.auth.hashers import make_password
from django.urls import reverse
from authentik.blueprints.tests import apply_blueprint
@@ -17,7 +16,6 @@ from authentik.tenants.flags import patch_flag
class TestSetup(FlowTestCase):
def tearDown(self):
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD", None)
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH", None)
environ.pop("AUTHENTIK_BOOTSTRAP_TOKEN", None)
@patch_flag(Setup, True)
@@ -156,19 +154,3 @@ class TestSetup(FlowTestCase):
token = Token.objects.filter(identifier="authentik-bootstrap-token").first()
self.assertEqual(token.intent, TokenIntents.INTENT_API)
self.assertEqual(token.key, environ["AUTHENTIK_BOOTSTRAP_TOKEN"])
def test_setup_bootstrap_env_password_hash(self):
"""Test setup with password hash env var"""
User.objects.filter(username="akadmin").delete()
Setup.set(False)
password = generate_id()
password_hash = make_password(password)
environ["AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"] = password_hash
pre_startup.send(sender=self)
post_startup.send(sender=self)
self.assertTrue(Setup.get())
user = User.objects.get(username="akadmin")
self.assertEqual(user.password, password_hash)
self.assertTrue(user.check_password(password))

View File

@@ -1,15 +1,8 @@
"""user tests"""
from unittest.mock import patch
from django.contrib.auth.hashers import make_password
from django.test.testcases import TestCase
from rest_framework.exceptions import ValidationError
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.users import UserSerializer
from authentik.core.models import User
from authentik.core.signals import password_changed, password_hash_changed
from authentik.events.models import Event
from authentik.lib.generators import generate_id
@@ -40,99 +33,3 @@ class TestUsers(TestCase):
self.assertEqual(Event.objects.count(), 1)
user.ak_groups.all()
self.assertEqual(Event.objects.count(), 1)
def test_set_password_from_hash_signal_skips_source_sync_receivers(self):
"""Test hash password updates do not expose a raw password to sync receivers."""
user = User.objects.create(
username=generate_id(),
attributes={"distinguishedName": "cn=test,ou=users,dc=example,dc=com"},
)
password_changed_captured = []
password_hash_changed_captured = []
dispatch_uid = generate_id()
hash_dispatch_uid = generate_id()
def password_changed_receiver(sender, **kwargs):
password_changed_captured.append(kwargs)
def password_hash_changed_receiver(sender, **kwargs):
password_hash_changed_captured.append(kwargs)
password_changed.connect(password_changed_receiver, dispatch_uid=dispatch_uid)
password_hash_changed.connect(
password_hash_changed_receiver, dispatch_uid=hash_dispatch_uid
)
try:
with (
patch(
"authentik.sources.ldap.signals.LDAPSource.objects.filter"
) as ldap_sources_filter,
patch(
"authentik.sources.kerberos.signals."
"UserKerberosSourceConnection.objects.select_related"
) as kerberos_connections_select,
):
user.set_password_from_hash(make_password("new-password")) # nosec
user.save()
finally:
password_changed.disconnect(dispatch_uid=dispatch_uid)
password_hash_changed.disconnect(dispatch_uid=hash_dispatch_uid)
self.assertEqual(password_changed_captured, [])
self.assertEqual(len(password_hash_changed_captured), 1)
ldap_sources_filter.assert_not_called()
kerberos_connections_select.assert_not_called()
class TestUserSerializerPasswordHash(TestCase):
    """Test UserSerializer password_hash support in blueprint context."""

    def test_password_hash_sets_password_directly(self):
        """Test a valid password hash is stored without re-hashing."""
        raw_password = "test-password-123"  # nosec
        hashed = make_password(raw_password)
        serializer = UserSerializer(
            context={SERIALIZER_CONTEXT_BLUEPRINT: True},
            data={
                "username": generate_id(),
                "name": "Test User",
                "password_hash": hashed,
            },
        )
        self.assertTrue(serializer.is_valid(), serializer.errors)
        created = serializer.save()
        # Stored hash must be byte-identical (no double hashing) and still
        # verify the original cleartext.
        self.assertEqual(created.password, hashed)
        self.assertTrue(created.check_password(raw_password))
        self.assertIsNotNone(created.password_change_date)

    def test_password_hash_rejects_invalid_format(self):
        """Test invalid password hash values are rejected."""
        serializer = UserSerializer(
            context={SERIALIZER_CONTEXT_BLUEPRINT: True},
            data={
                "username": generate_id(),
                "name": "Test User",
                "password_hash": "not-a-valid-hash",
            },
        )
        # Field-level validation passes; the hash format is only checked on save.
        self.assertTrue(serializer.is_valid(), serializer.errors)
        with self.assertRaises(ValidationError) as raised:
            serializer.save()
        self.assertIn("Invalid password hash format", str(raised.exception))

    def test_password_hash_ignored_outside_blueprint_context(self):
        """Test password_hash is not accepted by the regular serializer."""
        serializer = UserSerializer(
            data={
                "username": generate_id(),
                "name": "Test User",
                "password_hash": make_password("test"),  # nosec
            }
        )
        self.assertTrue(serializer.is_valid(), serializer.errors)
        # Without the blueprint context the field is dropped from validated_data.
        self.assertNotIn("password_hash", serializer.validated_data)

View File

@@ -3,7 +3,6 @@
from datetime import datetime, timedelta
from json import loads
from django.contrib.auth.hashers import make_password
from django.urls.base import reverse
from django.utils.timezone import now
from rest_framework.test import APITestCase
@@ -27,9 +26,6 @@ from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignatio
from authentik.lib.generators import generate_id, generate_key
from authentik.stages.email.models import EmailStage
INVALID_PASSWORD_HASH = "not-a-valid-hash"
INVALID_PASSWORD_HASH_ERROR = "Invalid password hash format. Must be a valid Django password hash."
class TestUsersAPI(APITestCase):
"""Test Users API"""
@@ -38,20 +34,6 @@ class TestUsersAPI(APITestCase):
self.admin = create_test_admin_user()
self.user = create_test_user()
def _set_password_hash(self, user: User, password_hash: str, client=None):
    """POST `password_hash` to the set-password-hash endpoint for `user`.

    Uses `client` when given, otherwise the test case's default client;
    returns the raw response for callers to assert on.
    """
    return (client or self.client).post(
        reverse("authentik_api:user-set-password-hash", kwargs={"pk": user.pk}),
        data={"password": password_hash},
    )
def _assert_password_hash_set(
    self, user: User, password: str, password_hash: str, response
) -> None:
    """Assert `response` succeeded and `user` now carries `password_hash`.

    Checks the stored hash is byte-identical (i.e. not re-hashed) and that
    the cleartext `password` still verifies against it.
    """
    self.assertEqual(response.status_code, 204, response.data)
    user.refresh_from_db()
    self.assertEqual(user.password, password_hash)
    self.assertTrue(user.check_password(password))
def test_filter_type(self):
"""Test API filtering by type"""
self.client.force_login(self.admin)
@@ -131,26 +113,6 @@ class TestUsersAPI(APITestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(response.content, {"password": ["This field may not be blank."]})
def test_set_password_hash(self):
    """Test setting a user's password from a hash."""
    self.client.force_login(self.admin)
    password = generate_key()
    password_hash = make_password(password)
    response = self._set_password_hash(self.user, password_hash)
    # Endpoint must store the hash verbatim; cleartext must still verify.
    self._assert_password_hash_set(self.user, password, password_hash, response)
def test_set_password_hash_invalid(self):
    """Test invalid password hashes are rejected."""
    self.client.force_login(self.admin)
    response = self._set_password_hash(self.user, INVALID_PASSWORD_HASH)
    # A malformed hash must fail with a field-level validation error.
    self.assertEqual(response.status_code, 400)
    self.assertJSONEqual(
        response.content,
        {"password": [INVALID_PASSWORD_HASH_ERROR]},
    )
def test_recovery(self):
"""Test user recovery link"""
flow = create_test_flow(
@@ -299,29 +261,6 @@ class TestUsersAPI(APITestCase):
self.assertTrue(token_filter.exists())
self.assertTrue(token_filter.first().expiring)
def test_service_account_set_password_hash(self):
    """Service account password hash can be set through the API."""
    self.client.force_login(self.admin)
    # Create a service account; these start without a usable password.
    response = self.client.post(
        reverse("authentik_api:user-service-account"),
        data={
            "name": "test-sa",
            "create_group": False,
        },
    )
    self.assertEqual(response.status_code, 200, response.data)
    body = loads(response.content)
    user = User.objects.get(pk=body["user_pk"])
    self.assertEqual(user.type, UserTypes.SERVICE_ACCOUNT)
    self.assertFalse(user.has_usable_password())
    # Setting a pre-computed hash must work for service accounts too.
    password = generate_key()
    password_hash = make_password(password)
    response = self._set_password_hash(user, password_hash)
    self._assert_password_hash_set(user, password, password_hash, response)
def test_service_account_no_expire(self):
"""Service account creation without token expiration"""
self.client.force_login(self.admin)

View File

@@ -12,7 +12,7 @@ from authentik.core.models import (
User,
UserTypes,
)
from authentik.core.signals import password_changed, password_hash_changed
from authentik.core.signals import password_changed
from authentik.enterprise.providers.ssf.models import (
EventTypes,
SSFProvider,
@@ -84,13 +84,14 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
)
def _send_password_credential_change(user: User, change_type: str):
@receiver(password_changed)
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
"""Credential change trigger (password changed)"""
send_ssf_events(
EventTypes.CAEP_CREDENTIAL_CHANGE,
{
"credential_type": "password",
"change_type": change_type,
"change_type": "revoke" if password is None else "update",
},
sub_id={
"format": "complex",
@@ -102,16 +103,6 @@ def _send_password_credential_change(user: User, change_type: str):
)
@receiver(password_hash_changed)
@receiver(password_changed)
def ssf_password_changed_cred_change(signal, sender, user: User, password: str | None = None, **_):
"""Credential change trigger (password changed)"""
if signal is password_hash_changed:
_send_password_credential_change(user, "update")
return
_send_password_credential_change(user, "revoke" if password is None else "update")
device_type_map = {
StaticDevice: "pin",
TOTPDevice: "pin",

View File

@@ -1,6 +1,5 @@
from uuid import uuid4
from django.contrib.auth.hashers import make_password
from django.urls import reverse
from rest_framework.test import APITestCase
@@ -53,21 +52,6 @@ class TestSignals(APITestCase):
)
self.assertEqual(res.status_code, 201, res.content)
def _assert_password_credential_change(self, user, change_type: str):
    """Assert a queued SSF CAEP credential-change event exists for `user`.

    Checks the stream for `self.provider`, the first queued event's status,
    and that the credential-change payload carries the expected
    `change_type`, a "password" credential type, and the user's email as
    the complex subject identifier.
    """
    stream = Stream.objects.filter(provider=self.provider).first()
    self.assertIsNotNone(stream)
    event = StreamEvent.objects.filter(stream=stream).first()
    self.assertIsNotNone(event)
    # NOTE(review): status is expected to be PENDING_FAILED — presumably no
    # live SSF receiver exists in the test environment; confirm.
    self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
    event_payload = event.payload["events"][
        "https://schemas.openid.net/secevent/caep/event-type/credential-change"
    ]
    self.assertEqual(event_payload["change_type"], change_type)
    self.assertEqual(event_payload["credential_type"], "password")
    self.assertEqual(event.payload["sub_id"]["format"], "complex")
    self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
    self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
def test_signal_logout(self):
"""Test user logout"""
user = create_test_user()
@@ -95,25 +79,19 @@ class TestSignals(APITestCase):
user.set_password(generate_id())
user.save()
self._assert_password_credential_change(user, "update")
def test_signal_password_change_from_hash(self):
    """Test user password change from a pre-hashed password."""
    user = create_test_user()
    self.client.force_login(user)
    # set_password_from_hash must trigger the same "update" credential-change
    # event as a regular cleartext password change.
    user.set_password_from_hash(make_password(generate_id()))
    user.save()
    self._assert_password_credential_change(user, "update")
def test_signal_password_revoke(self):
    """Test explicit password revoke."""
    user = create_test_user()
    self.client.force_login(user)
    # Setting the password to None revokes it -> change_type "revoke".
    user.set_password(None)
    user.save()
    self._assert_password_credential_change(user, "revoke")
stream = Stream.objects.filter(provider=self.provider).first()
self.assertIsNotNone(stream)
event = StreamEvent.objects.filter(stream=stream).first()
self.assertIsNotNone(event)
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
event_payload = event.payload["events"][
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
]
self.assertEqual(event_payload["change_type"], "update")
self.assertEqual(event_payload["credential_type"], "password")
self.assertEqual(event.payload["sub_id"]["format"], "complex")
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
def test_signal_authenticator_added(self):
"""Test authenticator creation signal"""

View File

@@ -11,7 +11,7 @@ from django.http import HttpRequest
from rest_framework.request import Request
from authentik.core.models import AuthenticatedSession, User
from authentik.core.signals import login_failed, password_changed, password_hash_changed
from authentik.core.signals import login_failed, password_changed
from authentik.events.models import Event, EventAction
from authentik.flows.models import Stage
from authentik.flows.planner import (
@@ -112,15 +112,8 @@ def on_invitation_used(sender, request: HttpRequest, invitation: Invitation, **_
)
@receiver(password_hash_changed)
@receiver(password_changed)
def on_password_changed(
sender,
user: User,
password: str | None = None,
request: HttpRequest | None = None,
**_,
):
def on_password_changed(sender, user: User, password: str, request: HttpRequest | None, **_):
"""Log password change"""
Event.new(EventAction.PASSWORD_SET).from_http(request, user=user)

View File

@@ -2,7 +2,6 @@
from urllib.parse import urlencode
from django.contrib.auth.hashers import make_password
from django.contrib.contenttypes.models import ContentType
from django.test import RequestFactory, TestCase
from django.views.debug import SafeExceptionReporterFilter
@@ -11,7 +10,7 @@ from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
@@ -214,14 +213,3 @@ class TestEvents(TestCase):
event = Event.new("unittest", foo="foo bar \u0000 baz")
event.save()
self.assertEqual(event.context["foo"], "foo bar baz")
def test_password_set_signal_on_set_password_from_hash(self):
    """Changing password from hash should still emit an audit event."""
    user = create_test_user()
    # Count existing PASSWORD_SET events first so the assertion is delta-based
    # and unaffected by events emitted during user creation.
    old_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
    user.set_password_from_hash(make_password(generate_id()))
    user.save()
    new_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
    self.assertEqual(new_count, old_count + 1)

View File

@@ -65,7 +65,6 @@ class OAuth2ProviderSerializer(ProviderSerializer):
fields = ProviderSerializer.Meta.fields + [
"authorization_flow",
"client_type",
"grant_types",
"client_id",
"client_secret",
"access_code_validity",

View File

@@ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from authentik.events.models import Event, EventAction
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.views import bad_request_message
from authentik.providers.oauth2.models import GrantType, RedirectURI
from authentik.providers.oauth2.models import GrantTypes, RedirectURI
class OAuth2Error(SentryIgnoredException):
@@ -182,7 +182,7 @@ class AuthorizeError(OAuth2Error):
# See:
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
fragment_or_query = (
"#" if self.grant_type in [GrantType.IMPLICIT, GrantType.HYBRID] else "?"
"#" if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID] else "?"
)
uri = (
@@ -225,7 +225,7 @@ class TokenError(OAuth2Error):
),
}
def __init__(self, error: str):
def __init__(self, error):
super().__init__()
self.error = error
self.description = self.errors[error]

View File

@@ -1,69 +0,0 @@
# Generated by Django 5.2.11 on 2026-02-17 11:04
import django.contrib.postgres.fields
from django.db import migrations, models
def migrate_default_grant_types():
    """Default for the new ``grant_types`` field on pre-existing rows.

    Returns every supported grant type so providers that existed before this
    migration keep their previous (implicitly all-enabled) behavior.
    """
    # Imported locally: migration modules must not import models at load time.
    from authentik.providers.oauth2.models import GrantType

    enabled = (
        GrantType.AUTHORIZATION_CODE,
        GrantType.HYBRID,
        GrantType.IMPLICIT,
        GrantType.CLIENT_CREDENTIALS,
        GrantType.PASSWORD,
        GrantType.DEVICE_CODE,
        GrantType.REFRESH_TOKEN,
    )
    return list(enabled)
class Migration(migrations.Migration):
    # Introduces OAuth2Provider.grant_types (Postgres ArrayField of text
    # choices) in two steps: AddField with a callable default backfills every
    # existing row with all grant types, then AlterField swaps the column
    # default to an empty list for rows created afterwards.

    dependencies = [
        (
            "authentik_providers_oauth2",
            "0031_remove_oauth2provider_backchannel_logout_uri_and_more",
        ),
    ]

    operations = [
        migrations.AddField(
            model_name="oauth2provider",
            name="grant_types",
            field=django.contrib.postgres.fields.ArrayField(
                base_field=models.TextField(
                    choices=[
                        ("authorization_code", "Authorization Code"),
                        ("implicit", "Implicit"),
                        ("hybrid", "Hybrid"),
                        ("refresh_token", "Refresh Token"),
                        ("client_credentials", "Client Credentials"),
                        ("password", "Password"),
                        ("urn:ietf:params:oauth:grant-type:device_code", "Device Code"),
                    ]
                ),
                # Callable default: existing providers get all grant types.
                default=migrate_default_grant_types,
                size=None,
            ),
        ),
        migrations.AlterField(
            model_name="oauth2provider",
            name="grant_types",
            field=django.contrib.postgres.fields.ArrayField(
                base_field=models.TextField(
                    choices=[
                        ("authorization_code", "Authorization Code"),
                        ("implicit", "Implicit"),
                        ("hybrid", "Hybrid"),
                        ("refresh_token", "Refresh Token"),
                        ("client_credentials", "Client Credentials"),
                        ("password", "Password"),
                        ("urn:ietf:params:oauth:grant-type:device_code", "Device Code"),
                    ]
                ),
                # Final default for newly created rows: empty list.
                default=list,
                size=None,
            ),
        ),
    ]

View File

@@ -19,7 +19,6 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from dacite import Config
from dacite.core import from_dict
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import HashIndex
from django.db import models
from django.http import HttpRequest
@@ -34,16 +33,7 @@ from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.brands.models import WebfingerProvider
from authentik.common.oauth.constants import (
GRANT_TYPE_AUTHORIZATION_CODE,
GRANT_TYPE_CLIENT_CREDENTIALS,
GRANT_TYPE_DEVICE_CODE,
GRANT_TYPE_HYBRID,
GRANT_TYPE_IMPLICIT,
GRANT_TYPE_PASSWORD,
GRANT_TYPE_REFRESH_TOKEN,
SubModes,
)
from authentik.common.oauth.constants import SubModes
from authentik.core.models import (
AuthenticatedSession,
ExpiringModel,
@@ -68,7 +58,7 @@ def generate_client_secret() -> str:
return generate_id(128)
class ClientType(models.TextChoices):
class ClientTypes(models.TextChoices):
"""Confidential clients are capable of maintaining the confidentiality
of their credentials. Public clients are incapable."""
@@ -76,16 +66,12 @@ class ClientType(models.TextChoices):
PUBLIC = "public", _("Public")
class GrantType(models.TextChoices):
class GrantTypes(models.TextChoices):
"""OAuth2 Grant types we support"""
AUTHORIZATION_CODE = GRANT_TYPE_AUTHORIZATION_CODE
IMPLICIT = GRANT_TYPE_IMPLICIT
HYBRID = GRANT_TYPE_HYBRID
REFRESH_TOKEN = GRANT_TYPE_REFRESH_TOKEN
CLIENT_CREDENTIALS = GRANT_TYPE_CLIENT_CREDENTIALS
PASSWORD = GRANT_TYPE_PASSWORD
DEVICE_CODE = GRANT_TYPE_DEVICE_CODE
AUTHORIZATION_CODE = "authorization_code"
IMPLICIT = "implicit"
HYBRID = "hybrid"
class ResponseMode(models.TextChoices):
@@ -202,15 +188,14 @@ class OAuth2Provider(WebfingerProvider, Provider):
client_type = models.CharField(
max_length=30,
choices=ClientType.choices,
default=ClientType.CONFIDENTIAL,
choices=ClientTypes.choices,
default=ClientTypes.CONFIDENTIAL,
verbose_name=_("Client Type"),
help_text=_(
"Confidential clients are capable of maintaining the confidentiality "
"of their credentials. Public clients are incapable"
),
)
grant_types = ArrayField(models.TextField(choices=GrantType.choices), default=list)
client_id = models.CharField(
max_length=255,
unique=True,

View File

@@ -22,7 +22,7 @@ from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, Red
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
GrantType,
GrantTypes,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -41,34 +41,12 @@ class TestAuthorize(OAuthTestCase):
super().setUp()
self.factory = RequestFactory()
def test_disallowed_grant_type(self):
    """Authorize requests fail when the provider has no grant types enabled."""
    OAuth2Provider.objects.create(
        name=generate_id(),
        client_id="test",
        grant_types=[],
        authorization_flow=create_test_flow(),
        redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
    )
    # Build the request OUTSIDE the assertRaises block: the original wrapped
    # request construction too, so any exception from factory.get() would
    # falsely satisfy the assertion. Only the code under test belongs inside.
    request = self.factory.get(
        "/",
        data={
            "response_type": "code",
            "client_id": "test",
            "redirect_uri": "http://local.invalid/Foo",
        },
    )
    with self.assertRaises(AuthorizeError) as cm:
        OAuthAuthorizationParams.from_request(request)
    self.assertEqual(cm.exception.error, "invalid_request")
def test_invalid_grant_type(self):
"""Test with invalid grant type"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
with self.assertRaises(AuthorizeError) as cm:
@@ -96,7 +74,6 @@ class TestAuthorize(OAuthTestCase):
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
grant_types=[GrantType.AUTHORIZATION_CODE],
)
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
@@ -211,7 +188,6 @@ class TestAuthorize(OAuthTestCase):
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
grant_types=[GrantType.AUTHORIZATION_CODE],
)
request = self.factory.get(
"/",
@@ -230,7 +206,6 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
provider.property_mappings.set(
@@ -252,14 +227,12 @@ class TestAuthorize(OAuthTestCase):
)
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type,
GrantType.AUTHORIZATION_CODE,
GrantTypes.AUTHORIZATION_CODE,
)
self.assertEqual(
OAuthAuthorizationParams.from_request(request).redirect_uri,
"http://local.invalid/Foo",
)
provider.grant_types = [GrantType.IMPLICIT]
provider.save()
request = self.factory.get(
"/",
data={
@@ -273,7 +246,7 @@ class TestAuthorize(OAuthTestCase):
)
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type,
GrantType.IMPLICIT,
GrantTypes.IMPLICIT,
)
# Implicit without openid scope
with self.assertRaises(AuthorizeError) as cm:
@@ -288,10 +261,8 @@ class TestAuthorize(OAuthTestCase):
)
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type,
GrantType.IMPLICIT,
GrantTypes.IMPLICIT,
)
provider.grant_types = [GrantType.HYBRID]
provider.save()
request = self.factory.get(
"/",
data={
@@ -303,7 +274,7 @@ class TestAuthorize(OAuthTestCase):
},
)
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type, GrantType.HYBRID
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
)
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
@@ -326,7 +297,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -363,7 +333,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
signing_key=self.keypair,
grant_types=[GrantType.IMPLICIT],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -435,7 +404,6 @@ class TestAuthorize(OAuthTestCase):
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
signing_key=self.keypair,
encryption_key=self.keypair,
grant_types=[GrantType.IMPLICIT],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -498,7 +466,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
signing_key=self.keypair,
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -548,7 +515,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
signing_key=self.keypair,
grant_types=[GrantType.IMPLICIT],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -606,7 +572,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
signing_key=self.keypair,
grant_types=[GrantType.AUTHORIZATION_CODE],
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
state = generate_id()
@@ -647,7 +612,6 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
grant_types=[GrantType.IMPLICIT],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
request = self.factory.get(
@@ -671,7 +635,6 @@ class TestAuthorize(OAuthTestCase):
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
grant_types=[GrantType.IMPLICIT],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -704,7 +667,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -735,7 +697,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -775,7 +736,6 @@ class TestAuthorize(OAuthTestCase):
authentication_flow=auth_flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -802,7 +762,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()

View File

@@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_flow
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.tests.utils import OAuthTestCase
@@ -22,7 +22,6 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
grant_types=[GrantType.DEVICE_CODE],
)
self.application = Application.objects.create(
name=generate_id(),
@@ -43,21 +42,6 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
reverse("authentik_providers_oauth2:device"),
)
self.assertEqual(res.status_code, 400)
def test_backchannel_invalid_no_grant(self):
    """Device backchannel rejects a provider with no grant types configured."""
    # Clear the provider's grants (setUp enabled device_code); the device
    # endpoint must then refuse the request.
    self.provider.grant_types = []
    self.provider.save()
    res = self.client.post(
        reverse("authentik_providers_oauth2:device"),
        data={
            "client_id": "test",
        },
    )
    self.assertEqual(res.status_code, 400)
def test_backchannel_invalid_no_app(self):
"""Test backchannel"""
# test without application
self.application.provider = None
self.application.save()

View File

@@ -9,7 +9,7 @@ from authentik.core.models import Application, Group
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
@@ -22,7 +22,6 @@ class TesOAuth2DeviceInit(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
grant_types=[GrantType.DEVICE_CODE],
)
self.application = Application.objects.create(
name=generate_id(),

View File

@@ -14,7 +14,7 @@ from authentik.lib.generators import generate_id
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import (
AccessToken,
ClientType,
ClientTypes,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -173,7 +173,7 @@ class TesOAuth2Introspection(OAuthTestCase):
def test_introspect_provider_public(self):
"""Test introspect"""
self.provider.client_type = ClientType.PUBLIC
self.provider.client_type = ClientTypes.PUBLIC
self.provider.save()
token = AccessToken.objects.create(
provider=self.provider,
@@ -208,7 +208,7 @@ class TesOAuth2Introspection(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
signing_key=create_test_cert(),
client_type=ClientType.PUBLIC,
client_type=ClientTypes.PUBLIC,
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)

View File

@@ -13,7 +13,7 @@ from authentik.lib.generators import generate_id
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import (
AccessToken,
ClientType,
ClientTypes,
DeviceToken,
OAuth2Provider,
RedirectURI,
@@ -126,7 +126,7 @@ class TesOAuth2Revoke(OAuthTestCase):
def test_revoke_public(self):
"""Test revoke public client"""
self.provider.client_type = ClientType.PUBLIC
self.provider.client_type = ClientTypes.PUBLIC
self.provider.save()
token = AccessToken.objects.create(
provider=self.provider,
@@ -241,7 +241,7 @@ class TesOAuth2Revoke(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
signing_key=create_test_cert(),
client_type=ClientType.PUBLIC,
client_type=ClientTypes.PUBLIC,
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
@@ -270,14 +270,14 @@ class TesOAuth2Revoke(OAuthTestCase):
def test_revoke_provider_fed_public(self):
"""Test revoke with federation. self.provider is a public
client and other_provider is a public client."""
self.provider.client_type = ClientType.PUBLIC
self.provider.client_type = ClientTypes.PUBLIC
self.provider.save()
other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
signing_key=create_test_cert(),
client_type=ClientType.PUBLIC,
client_type=ClientTypes.PUBLIC,
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)

View File

@@ -25,7 +25,6 @@ from authentik.providers.oauth2.errors import TokenError
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -45,39 +44,11 @@ class TestToken(OAuthTestCase):
self.factory = RequestFactory()
self.app = Application.objects.create(name=generate_id(), slug="test")
def test_invalid_grant_type(self):
    """Token request fails when the grant type is not configured on the provider."""
    # Provider with an empty grant_types list: no grant is allowed at all.
    provider = OAuth2Provider.objects.create(
        name=generate_id(),
        authorization_flow=create_test_flow(),
        grant_types=[],
        redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
        signing_key=self.keypair,
    )
    # HTTP Basic client authentication header for the token endpoint.
    header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
    user = create_test_admin_user()
    code = AuthorizationCode.objects.create(
        code="foobar", provider=provider, user=user, auth_time=timezone.now()
    )
    request = self.factory.post(
        "/",
        data={
            "grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
            "code": code.code,
            "redirect_uri": "http://TestServer",
        },
        HTTP_AUTHORIZATION=f"Basic {header}",
    )
    # Even with a valid code, parsing must fail because authorization_code
    # is not among the provider's configured grants.
    with self.assertRaises(TokenError) as cm:
        TokenParams.parse(request, provider, provider.client_id, provider.client_secret)
    self.assertEqual(cm.exception.cause, "grant_type_not_configured")
def test_request_auth_code(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
signing_key=self.keypair,
)
@@ -105,7 +76,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.keypair,
)
@@ -127,7 +97,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
)
@@ -170,7 +139,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
)
@@ -211,7 +179,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
encryption_key=self.keypair,
@@ -243,7 +210,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
)
@@ -305,7 +271,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
)
@@ -363,7 +328,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.keypair,
)
@@ -436,7 +400,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.REFRESH_TOKEN],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
refresh_token_threshold="hours=1", # nosec
@@ -534,7 +497,6 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
grant_types=[GrantType.AUTHORIZATION_CODE],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
signing_key=self.keypair,
include_claims_in_id_token=True,

View File

@@ -22,7 +22,6 @@ from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import (
AccessToken,
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -56,7 +55,6 @@ class TestTokenClientCredentialsJWTProvider(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
grant_types=[GrantType.CLIENT_CREDENTIALS],
)
self.provider.jwt_federation_providers.add(self.other_provider)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -20,7 +20,6 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import (
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -69,7 +68,6 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
grant_types=[GrantType.CLIENT_CREDENTIALS],
)
self.provider.jwt_federation_sources.add(self.source)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -21,7 +21,6 @@ from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.errors import TokenError
from authentik.providers.oauth2.models import (
AccessToken,
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -42,7 +41,6 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=create_test_cert(),
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

View File

@@ -22,7 +22,6 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_cert,
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.errors import TokenError
from authentik.providers.oauth2.models import (
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -43,7 +42,6 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=create_test_cert(),
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

View File

@@ -25,7 +25,6 @@ from authentik.core.tests.utils import (
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.errors import TokenError
from authentik.providers.oauth2.models import (
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -46,7 +45,6 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=create_test_cert(),
grant_types=[GrantType.CLIENT_CREDENTIALS, GrantType.PASSWORD],
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

View File

@@ -17,7 +17,6 @@ from authentik.lib.generators import generate_code_fixed_length, generate_id
from authentik.providers.oauth2.models import (
AccessToken,
DeviceToken,
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -38,7 +37,6 @@ class TestTokenDeviceCode(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=create_test_cert(),
grant_types=[GrantType.DEVICE_CODE],
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

View File

@@ -11,7 +11,6 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import (
AuthorizationCode,
GrantType,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -38,7 +37,6 @@ class TestTokenPKCE(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -97,7 +95,6 @@ class TestTokenPKCE(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -154,7 +151,6 @@ class TestTokenPKCE(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@@ -200,7 +196,6 @@ class TestTokenPKCE(OAuthTestCase):
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
grant_types=[GrantType.AUTHORIZATION_CODE],
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()

View File

@@ -57,7 +57,7 @@ from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
GrantType,
GrantTypes,
OAuth2Provider,
RedirectURIMatchingMode,
ResponseMode,
@@ -164,31 +164,28 @@ class OAuthAuthorizationParams:
"""Check grant"""
# Determine which flow to use.
if self.response_type in [ResponseTypes.CODE]:
self.grant_type = GrantType.AUTHORIZATION_CODE
self.grant_type = GrantTypes.AUTHORIZATION_CODE
elif self.response_type in [
ResponseTypes.ID_TOKEN,
ResponseTypes.ID_TOKEN_TOKEN,
]:
self.grant_type = GrantType.IMPLICIT
self.grant_type = GrantTypes.IMPLICIT
elif self.response_type in [
ResponseTypes.CODE_TOKEN,
ResponseTypes.CODE_ID_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
]:
self.grant_type = GrantType.HYBRID
self.grant_type = GrantTypes.HYBRID
# Grant type validation.
if not self.grant_type:
LOGGER.warning("Invalid response type", type=self.response_type)
raise AuthorizeError(self.redirect_uri, "unsupported_response_type", "", self.state)
if self.grant_type not in self.provider.grant_types:
LOGGER.warning("Invalid grant_type for provider", grant_type=self.grant_type)
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
if self.response_mode not in ResponseMode.values:
self.response_mode = ResponseMode.QUERY
if self.grant_type in [GrantType.IMPLICIT, GrantType.HYBRID]:
if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]:
self.response_mode = ResponseMode.FRAGMENT
def check_redirect_uri(self):
@@ -249,7 +246,7 @@ class OAuthAuthorizationParams:
)
self.scope = self.scope.intersection(default_scope_names)
if SCOPE_OPENID not in self.scope and (
self.grant_type == GrantType.HYBRID
self.grant_type == GrantTypes.HYBRID
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
):
LOGGER.warning("Missing 'openid' scope.")
@@ -600,8 +597,8 @@ class OAuthFulfillmentStage(StageView):
code = None
if self.params.grant_type in [
GrantType.AUTHORIZATION_CODE,
GrantType.HYBRID,
GrantTypes.AUTHORIZATION_CODE,
GrantTypes.HYBRID,
]:
code = self.params.create_code(self.request)
code.save()
@@ -616,7 +613,7 @@ class OAuthFulfillmentStage(StageView):
if self.params.response_mode == ResponseMode.FRAGMENT:
query_fragment = {}
if self.params.grant_type in [GrantType.AUTHORIZATION_CODE]:
if self.params.grant_type in [GrantTypes.AUTHORIZATION_CODE]:
query_fragment["code"] = code.code
query_fragment["state"] = [str(self.params.state) if self.params.state else ""]
else:
@@ -630,7 +627,7 @@ class OAuthFulfillmentStage(StageView):
if self.params.response_mode == ResponseMode.FORM_POST:
post_params = {}
if self.params.grant_type in [GrantType.AUTHORIZATION_CODE]:
if self.params.grant_type in [GrantTypes.AUTHORIZATION_CODE]:
post_params["code"] = code.code
post_params["state"] = [str(self.params.state) if self.params.state else ""]
else:
@@ -699,7 +696,7 @@ class OAuthFulfillmentStage(StageView):
token.save()
# Code parameter must be present if it's Hybrid Flow.
if self.params.grant_type == GrantType.HYBRID:
if self.params.grant_type == GrantTypes.HYBRID:
query_fragment["code"] = code.code
query_fragment["token_type"] = TOKEN_TYPE

View File

@@ -15,7 +15,7 @@ from authentik.core.models import Application
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.errors import DeviceCodeError
from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
@@ -42,8 +42,6 @@ class DeviceView(View):
_ = provider.application
except Application.DoesNotExist:
raise DeviceCodeError("invalid_client") from None
if GrantType.DEVICE_CODE not in provider.grant_types:
raise DeviceCodeError("invalid_client")
self.provider = provider
self.client_id = client_id

View File

@@ -11,7 +11,7 @@ from structlog.stdlib import get_logger
from authentik.providers.oauth2.errors import TokenIntrospectionError
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import AccessToken, ClientType, OAuth2Provider, RefreshToken
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
from authentik.providers.oauth2.utils import TokenResponse, authenticate_provider
LOGGER = get_logger()
@@ -45,7 +45,7 @@ class TokenIntrospectionParams:
if not provider:
LOGGER.info("Failed to authenticate introspection request")
raise TokenIntrospectionError
if provider.client_type != ClientType.CONFIDENTIAL:
if provider.client_type != ClientTypes.CONFIDENTIAL:
LOGGER.info("Introspection request from public provider, denying.")
raise TokenIntrospectionError

View File

@@ -58,7 +58,7 @@ from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
ClientType,
ClientTypes,
DeviceToken,
OAuth2Provider,
RedirectURIMatchingMode,
@@ -165,10 +165,6 @@ class TokenParams:
raise TokenError("invalid_grant")
def __post_init__(self, raw_code: str, raw_token: str, request: HttpRequest):
if self.grant_type not in self.provider.grant_types:
LOGGER.warning("Invalid grant_type for provider", grant_type=self.grant_type)
raise TokenError("invalid_grant").with_cause("grant_type_not_configured")
# Confidential clients MUST authenticate to the token endpoint per
# RFC 6749 §2.3.1. The device code grant (RFC 8628 §3.4) inherits
# that requirement - the device_code alone is not a substitute for
@@ -178,7 +174,7 @@ class TokenParams:
GRANT_TYPE_REFRESH_TOKEN,
GRANT_TYPE_DEVICE_CODE,
]:
if self.provider.client_type == ClientType.CONFIDENTIAL and not compare_digest(
if self.provider.client_type == ClientTypes.CONFIDENTIAL and not compare_digest(
self.provider.client_secret, self.client_secret
):
LOGGER.warning(
@@ -610,10 +606,10 @@ class TokenView(View):
if not self.provider:
LOGGER.warning("OAuth2Provider does not exist", client_id=client_id)
raise TokenError("invalid_client")
CTX_AUTH_VIA.set("oauth_client_secret")
self.params = self.params_class.parse(
request, self.provider, client_id, client_secret
)
CTX_AUTH_VIA.set("oauth_client_secret")
with start_span(
op="authentik.providers.oauth2.post.response",

View File

@@ -10,7 +10,7 @@ from django.views.decorators.csrf import csrf_exempt
from structlog.stdlib import get_logger
from authentik.providers.oauth2.errors import TokenRevocationError
from authentik.providers.oauth2.models import AccessToken, ClientType, OAuth2Provider, RefreshToken
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
from authentik.providers.oauth2.utils import (
TokenResponse,
authenticate_provider,
@@ -33,13 +33,11 @@ class TokenRevocationParams:
raw_token = request.POST.get("token")
provider, _, _ = provider_from_request(request)
if provider and provider.client_type == ClientType.CONFIDENTIAL:
provider = authenticate_provider(request)
if not provider:
raise TokenRevocationError("invalid_client")
# By default clients can only revoke their own tokens
query = Q(provider=provider, token=raw_token)
if provider.client_type == ClientType.CONFIDENTIAL:
if provider.client_type == ClientTypes.CONFIDENTIAL:
provider = authenticate_provider(request)
if not provider:
raise TokenRevocationError("invalid_client")

View File

@@ -16,8 +16,7 @@ from authentik.crypto.models import CertificateKeyPair
from authentik.lib.models import DomainlessURLValidator, InternallyManagedMixin
from authentik.outposts.models import OutpostModel
from authentik.providers.oauth2.models import (
ClientType,
GrantType,
ClientTypes,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
@@ -162,12 +161,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider):
def set_oauth_defaults(self):
"""Ensure all OAuth2-related settings are correct"""
self.grant_types = [
GrantType.AUTHORIZATION_CODE,
GrantType.CLIENT_CREDENTIALS,
GrantType.PASSWORD,
]
self.client_type = ClientType.CONFIDENTIAL
self.client_type = ClientTypes.CONFIDENTIAL
self.signing_key = None
self.include_claims_in_id_token = True
scopes = ScopeMapping.objects.filter(

View File

@@ -9,7 +9,7 @@ from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.lib.generators import generate_id
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.models import ClientType
from authentik.providers.oauth2.models import ClientTypes
from authentik.providers.proxy.models import ProxyMode, ProxyProvider
@@ -96,7 +96,7 @@ class ProxyProviderTests(APITestCase):
)
self.assertEqual(response.status_code, 201)
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
def test_update_defaults(self):
"""Test create"""
@@ -114,8 +114,8 @@ class ProxyProviderTests(APITestCase):
)
self.assertEqual(response.status_code, 201)
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
provider.client_type = ClientType.PUBLIC
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
provider.client_type = ClientTypes.PUBLIC
provider.save()
response = self.client.put(
reverse("authentik_api:proxyprovider-detail", kwargs={"pk": provider.pk}),
@@ -130,7 +130,7 @@ class ProxyProviderTests(APITestCase):
)
self.assertEqual(response.status_code, 200)
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
self.assertEqual(provider.client_type, ClientType.CONFIDENTIAL)
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
def test_sa_fetch(self):
"""Test fetching the outpost config as the service account"""

View File

@@ -24,11 +24,7 @@ from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.validation import validate
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
SAML_BINDING_POST,
SAML_BINDING_REDIRECT,
)
from authentik.common.saml.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
@@ -59,7 +55,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"""SAMLProvider Serializer"""
url_download_metadata = SerializerMethodField()
url_issuer = SerializerMethodField()
url_sso_post = SerializerMethodField()
url_sso_redirect = SerializerMethodField()
@@ -90,23 +85,6 @@ class SAMLProviderSerializer(ProviderSerializer):
+ "?download"
)
def get_url_issuer(self, instance: SAMLProvider) -> str:
"""Get Issuer/EntityID URL"""
if instance.issuer_override:
return instance.issuer_override
if "request" not in self._context:
return DEFAULT_ISSUER
request: HttpRequest = self._context["request"]._request
try:
return request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": instance.application.slug},
)
)
except Provider.application.RelatedObjectDoesNotExist:
return DEFAULT_ISSUER
def get_url_sso_post(self, instance: SAMLProvider) -> str:
"""Get SSO Post URL"""
if "request" not in self._context:
@@ -220,7 +198,7 @@ class SAMLProviderSerializer(ProviderSerializer):
"acs_url",
"sls_url",
"audience",
"issuer_override",
"issuer",
"assertion_valid_not_before",
"assertion_valid_not_on_or_after",
"session_valid_not_on_or_after",
@@ -242,7 +220,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"default_relay_state",
"default_name_id_policy",
"url_download_metadata",
"url_issuer",
"url_sso_post",
"url_sso_redirect",
"url_sso_init",

View File

@@ -1,34 +0,0 @@
# Generated by Django 5.2.11 on 2026-02-24 06:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0021_samlprovider_sign_logout_response"),
]
operations = [
migrations.RenameField(
model_name="samlprovider",
old_name="issuer",
new_name="issuer_override",
),
migrations.AlterField(
model_name="samlprovider",
name="issuer_override",
field=models.TextField(
blank=True,
default="",
help_text="Also known as EntityID. Providing a value overrides the default issuer generated by authentik.",
),
),
migrations.AddField(
model_name="samlsession",
name="issuer",
field=models.TextField(
default=None, help_text="SAML Issuer used for this session", null=True
),
),
]

View File

@@ -77,14 +77,7 @@ class SAMLProvider(Provider):
"no audience restriction will be added."
),
)
issuer_override = models.TextField(
blank=True,
default="",
help_text=_(
"Also known as EntityID. Providing a value overrides the default issuer "
"generated by authentik."
),
)
issuer = models.TextField(help_text=_("Also known as EntityID"), default="authentik")
sls_url = models.TextField(
blank=True,
validators=[DomainlessURLValidator(schemes=("http", "https"))],
@@ -325,9 +318,6 @@ class SAMLSession(InternallyManagedMixin, SerializerModel, ExpiringModel):
session_index = models.TextField(help_text=_("SAML SessionIndex for this session"))
name_id = models.TextField(help_text=_("SAML NameID value for this session"))
name_id_format = models.TextField(default="", blank=True, help_text=_("SAML NameID format"))
issuer = models.TextField(
default=None, null=True, help_text=_("SAML Issuer used for this session")
)
created = models.DateTimeField(auto_now_add=True)
@property

View File

@@ -6,7 +6,6 @@ from types import GeneratorType
import xmlsec
from django.http import HttpRequest
from django.urls import reverse
from django.utils.timezone import now
from lxml import etree # nosec
from lxml.etree import Element, SubElement, _Element # nosec
@@ -64,7 +63,6 @@ class AssertionProcessor:
session_index: str
name_id: str
name_id_format: str
issuer: str
session_not_on_or_after_datetime: datetime
def __init__(self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest):
@@ -139,24 +137,10 @@ class AssertionProcessor:
continue
return attribute_statement
def _get_issuer_value(self) -> str:
"""Get issuer value, with fallback to generated URL if empty"""
# If user has set an override issuer, use it
if self.provider.issuer_override:
return self.provider.issuer_override
return self.http_request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": self.provider.application.slug},
)
)
def get_issuer(self) -> Element:
"""Get Issuer Element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer", nsmap=NS_MAP)
self.issuer = self._get_issuer_value()
issuer.text = self.issuer
issuer.text = self.provider.issuer
return issuer
def get_assertion_auth_n_statement(self) -> Element:

View File

@@ -8,7 +8,6 @@ from lxml import etree # nosec
from lxml.etree import Element, _Element
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP,
NS_SAML_ASSERTION,
@@ -34,12 +33,11 @@ class LogoutRequestProcessor:
name_id_format: str
session_index: str | None
relay_state: str | None
issuer: str | None
_issue_instant: str
_request_id: str
def __init__( # noqa: PLR0913
def __init__(
self,
provider: SAMLProvider,
user: User | None,
@@ -48,7 +46,6 @@ class LogoutRequestProcessor:
name_id_format: str = SAML_NAME_ID_FORMAT_EMAIL,
session_index: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
self.provider = provider
self.user = user
@@ -57,23 +54,14 @@ class LogoutRequestProcessor:
self.name_id_format = name_id_format
self.session_index = session_index
self.relay_state = relay_state
self.issuer = issuer
self._issue_instant = get_time_string()
self._request_id = get_random_id()
def _get_issuer_value(self) -> str:
"""Get issuer value from session, with fallback to provider"""
if self.issuer:
return self.issuer
if self.provider.issuer_override:
return self.provider.issuer_override
return DEFAULT_ISSUER
def get_issuer(self) -> Element:
"""Get Issuer element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
issuer.text = self._get_issuer_value()
issuer.text = self.provider.issuer
return issuer
def get_name_id(self) -> Element:

View File

@@ -8,7 +8,6 @@ from lxml import etree
from lxml.etree import Element, SubElement
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP,
NS_SAML_ASSERTION,
@@ -29,38 +28,27 @@ class LogoutResponseProcessor:
logout_request: LogoutRequest
destination: str | None
relay_state: str | None
issuer: str | None
_issue_instant: str
_response_id: str
def __init__( # noqa: PLR0913
def __init__(
self,
provider: SAMLProvider,
logout_request: LogoutRequest,
destination: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
self.provider = provider
self.logout_request = logout_request
self.destination = destination
self.relay_state = relay_state or (logout_request.relay_state if logout_request else None)
self.issuer = issuer
self._issue_instant = get_time_string()
self._response_id = get_random_id()
def _get_issuer_value(self) -> str:
"""Get issuer value from session, with fallback to provider"""
if self.issuer:
return self.issuer
if self.provider.issuer_override:
return self.provider.issuer_override
return DEFAULT_ISSUER
def get_issuer(self) -> Element:
"""Get Issuer element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
issuer.text = self._get_issuer_value()
issuer.text = self.provider.issuer
return issuer
def build(self, status: str = "Success") -> Element:

View File

@@ -40,19 +40,6 @@ class MetadataProcessor:
self.force_binding = None
self.xml_id = "_" + sha256(f"{provider.name}-{provider.pk}".encode("ascii")).hexdigest()
def _get_issuer_value(self) -> str:
"""Get issuer value, with fallback to generated URL if empty"""
# If user has set an override issuer, use it
if self.provider.issuer_override:
return self.provider.issuer_override
return self.http_request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": self.provider.application.slug},
)
)
# Using type unions doesn't work with cython types (which is what lxml is)
def get_signing_key_descriptor(self) -> Element | None:
"""Get Signing KeyDescriptor, if enabled for the provider"""
@@ -202,7 +189,7 @@ class MetadataProcessor:
"""Build full EntityDescriptor"""
entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP)
entity_descriptor.attrib["ID"] = self.xml_id
entity_descriptor.attrib["entityID"] = self._get_issuer_value()
entity_descriptor.attrib["entityID"] = self.provider.issuer
if self.provider.signing_kp:
self._prepare_signature(entity_descriptor)

View File

@@ -51,6 +51,7 @@ class ServiceProviderMetadata:
provider = SAMLProvider.objects.create(
name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
)
provider.issuer = self.entity_id
provider.sp_binding = self.acs_binding
provider.acs_url = self.acs_location
provider.default_name_id_policy = self.name_id_policy

View File

@@ -75,7 +75,6 @@ def handle_saml_iframe_pre_user_logout(
name_id_format=session.name_id_format,
session_index=session.session_index,
relay_state=relay_state,
issuer=session.issuer,
)
if session.provider.sls_binding == SAMLBindings.POST:
@@ -164,7 +163,6 @@ def handle_flow_pre_user_logout(
name_id_format=session.name_id_format,
session_index=session.session_index,
relay_state=relay_state,
issuer=session.issuer,
)
if session.provider.sls_binding == SAMLBindings.POST:
@@ -226,7 +224,6 @@ def user_session_deleted_saml_logout(sender, instance: AuthenticatedSession, **_
name_id=saml_session.name_id,
name_id_format=saml_session.name_id_format,
session_index=saml_session.session_index,
issuer=saml_session.issuer,
)
@@ -260,5 +257,4 @@ def user_deactivated_saml_logout(sender, instance: User, **kwargs):
name_id=saml_session.name_id,
name_id_format=saml_session.name_id_format,
session_index=saml_session.session_index,
issuer=saml_session.issuer,
)

View File

@@ -22,7 +22,6 @@ def send_saml_logout_request(
name_id: str,
name_id_format: str,
session_index: str,
issuer: str,
):
"""Send SAML LogoutRequest to a Service Provider using session data"""
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
@@ -48,7 +47,6 @@ def send_saml_logout_request(
name_id=name_id,
name_id_format=name_id_format,
session_index=session_index,
issuer=issuer,
)
return send_post_logout_request(provider, processor)
@@ -91,7 +89,6 @@ def send_saml_logout_response(
sls_url: str,
logout_request_id: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
"""Send SAML LogoutResponse to a Service Provider using backchannel (server-to-server)"""
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
@@ -122,7 +119,6 @@ def send_saml_logout_response(
logout_request=logout_request,
destination=sls_url,
relay_state=relay_state,
issuer=issuer,
)
encoded_response = processor.encode_post()

View File

@@ -15,7 +15,6 @@ from authentik.common.saml.constants import (
SAML_NAME_ID_FORMAT_EMAIL,
SAML_NAME_ID_FORMAT_UNSPECIFIED,
)
from authentik.core.models import Application
from authentik.core.tests.utils import (
RequestFactory,
create_test_admin_user,
@@ -98,11 +97,6 @@ class TestAuthNRequest(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
Application.objects.create(
name="test-app",
slug="test-app",
provider=self.provider,
)
self.source = SAMLSource.objects.create(
slug="provider",
issuer="authentik",
@@ -532,7 +526,7 @@ class TestAuthNRequest(TestCase):
authorization_flow=create_test_flow(),
acs_url="https://10.120.20.200/saml-sp/SAML2/POST",
audience="https://10.120.20.200/saml-sp/SAML2/POST",
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
signing_kp=static_keypair,
verification_kp=static_keypair,
)
@@ -553,7 +547,7 @@ class TestAuthNRequest(TestCase):
"saml/acs/2d737f96-55fb-4035-953e-5e24134eb778"
),
audience="https://10.120.20.200/saml-sp/SAML2/POST",
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
signing_kp=create_test_cert(),
)
parsed_request = AuthNRequestParser(provider).parse(POST_REQUEST)

View File

@@ -47,7 +47,7 @@ class TestNativeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp1.example.com/acs",
sls_url="https://sp1.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
@@ -58,7 +58,7 @@ class TestNativeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="post",
sls_binding="post",
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
@@ -218,7 +218,7 @@ class TestIframeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp1.example.com/acs",
sls_url="https://sp1.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
logout_method="frontchannel_iframe",
@@ -229,7 +229,7 @@ class TestIframeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="post",
sls_binding="post",
logout_method="frontchannel_iframe",
@@ -372,7 +372,7 @@ class TestIdPLogoutIntegration(FlowTestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signing_kp=self.keypair,

View File

@@ -28,7 +28,7 @@ class TestLogoutIntegration(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signature_algorithm=RSA_SHA256,
@@ -57,7 +57,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -72,7 +72,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse_detached(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -106,7 +106,7 @@ class TestLogoutIntegration(TestCase):
parsed = parser.parse(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "signed@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "signed-session-456")
@@ -125,7 +125,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse_detached(saml_request)
# Verify parsing succeeded
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
@@ -164,7 +164,7 @@ class TestLogoutIntegration(TestCase):
# Parse the SAMLRequest (unsigned XML)
parsed = self.parser.parse_detached(params["SAMLRequest"][0])
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
def test_form_data_can_be_parsed(self):
"""Test that form data generates parseable POST request"""
@@ -175,7 +175,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse(form_data["SAMLRequest"])
# Verify parsing succeeded
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -244,4 +244,4 @@ class TestLogoutIntegration(TestCase):
# But same issuer
self.assertEqual(parsed1.issuer, parsed2.issuer)
self.assertEqual(parsed1.issuer, self.provider.issuer_override)
self.assertEqual(parsed1.issuer, self.provider.issuer)

View File

@@ -35,7 +35,7 @@ class TestLogoutRequestProcessor(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signature_algorithm=RSA_SHA256,

View File

@@ -1,7 +1,7 @@
"""logout response tests"""
from defusedxml import ElementTree
from django.test import RequestFactory, TestCase
from django.test import TestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.common.saml.constants import (
@@ -9,13 +9,10 @@ from authentik.common.saml.constants import (
NS_SAML_PROTOCOL,
NS_SIGNATURE,
)
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.lib.generators import generate_id
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.logout_request_parser import LogoutRequest
from authentik.providers.saml.processors.logout_response_processor import LogoutResponseProcessor
from authentik.providers.saml.processors.metadata import MetadataProcessor
class TestLogoutResponse(TestCase):
@@ -24,7 +21,6 @@ class TestLogoutResponse(TestCase):
@apply_blueprint("system/providers-saml.yaml")
def setUp(self):
cert = create_test_cert()
self.factory = RequestFactory()
self.provider: SAMLProvider = SAMLProvider.objects.create(
authorization_flow=create_test_flow(),
acs_url="http://testserver/source/saml/provider/acs/",
@@ -34,31 +30,17 @@ class TestLogoutResponse(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
self.application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
def test_build_response(self):
"""Test building a LogoutResponse uses the generated issuer from the assertion"""
# Generate the issuer the same way the assertion/metadata processors would
request = self.factory.get("/")
metadata_processor = MetadataProcessor(self.provider, request)
generated_issuer = metadata_processor._get_issuer_value()
"""Test building a LogoutResponse"""
logout_request = LogoutRequest(
id="test-request-id",
issuer="test-sp",
relay_state="test-relay-state",
)
# Pass the generated issuer as if it came from SAMLSession.issuer
processor = LogoutResponseProcessor(
self.provider,
logout_request,
destination=self.provider.sls_url,
issuer=generated_issuer,
self.provider, logout_request, destination=self.provider.sls_url
)
response_xml = processor.build_response(status="Success")
@@ -69,9 +51,9 @@ class TestLogoutResponse(TestCase):
self.assertEqual(root.attrib["Destination"], self.provider.sls_url)
self.assertEqual(root.attrib["InResponseTo"], "test-request-id")
# Check Issuer matches the generated issuer from the assertion processor
# Check Issuer
issuer = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
self.assertEqual(issuer.text, generated_issuer)
self.assertEqual(issuer.text, self.provider.issuer)
# Check Status
status = root.find(f".//{{{NS_SAML_PROTOCOL}}}StatusCode")

View File

@@ -85,6 +85,7 @@ class TestServiceProviderMetadataParser(TestCase):
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml"))
provider = metadata.to_provider("test", self.flow, self.flow)
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertEqual(provider.default_name_id_policy, SAMLNameIDPolicy.EMAIL)
self.assertEqual(
@@ -98,6 +99,7 @@ class TestServiceProviderMetadataParser(TestCase):
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/cert.xml"))
provider = metadata.to_provider("test", self.flow, self.flow)
self.assertEqual(provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/apps/user_saml/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertEqual(
provider.verification_kp.certificate_data, load_fixture("fixtures/cert.pem")

View File

@@ -32,7 +32,7 @@ class TestSAMLSessionModel(TestCase):
name="test-provider",
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
)
# Create another provider for testing
@@ -40,7 +40,7 @@ class TestSAMLSessionModel(TestCase):
name="test-provider-2",
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
issuer_override="https://idp2.example.com",
issuer="https://idp2.example.com",
)
# Create a session first (using authentik's custom Session model)
@@ -72,7 +72,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify the session was created
@@ -101,7 +100,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Try to create another session with same session_index and provider
@@ -115,7 +113,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
def test_cascade_deletion_user(self):
@@ -130,7 +127,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -154,7 +150,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -178,7 +173,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -202,7 +196,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Create second session with different provider
@@ -215,7 +208,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify both sessions exist
@@ -237,7 +229,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=future_time,
expiring=True,
issuer="authentik",
)
# Verify expiry time
@@ -257,7 +248,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=past_time,
expiring=True,
issuer="authentik",
)
# Check if marked as expired
@@ -275,7 +265,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format="", # Blank format
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify it was created successfully
@@ -294,7 +283,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
session2 = SAMLSession.objects.create(
@@ -306,7 +294,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Query by provider
@@ -329,7 +316,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Check serializer property
@@ -348,7 +334,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify sessions exist

View File

@@ -7,7 +7,6 @@ from guardian.shortcuts import get_anonymous_user
from lxml import etree # nosec
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import RequestFactory, create_test_cert, create_test_flow
from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
@@ -31,11 +30,6 @@ class TestSchema(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
Application.objects.create(
name="test-app",
slug="test-app",
provider=self.provider,
)
self.source = SAMLSource.objects.create(
slug="provider",
issuer="authentik",

View File

@@ -28,7 +28,7 @@ class TestSendSamlLogoutResponse(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)
@@ -137,7 +137,7 @@ class TestSendSamlLogoutRequest(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)
@@ -155,7 +155,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
self.assertTrue(result)
@@ -180,7 +179,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
self.assertFalse(result)
@@ -200,7 +198,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
@@ -217,7 +214,7 @@ class TestSendPostLogoutRequest(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)

View File

@@ -40,7 +40,7 @@ class TestSPInitiatedSLOViews(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
)
@@ -90,7 +90,7 @@ class TestSPInitiatedSLOViews(TestCase):
# Verify logout request was stored in plan context
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
self.assertEqual(logout_request.issuer, self.provider.issuer)
self.assertEqual(logout_request.session_index, "test-session-123")
def test_redirect_view_handles_logout_response_with_plan_context(self):
@@ -228,7 +228,7 @@ class TestSPInitiatedSLOViews(TestCase):
# Verify logout request was stored in plan context
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
self.assertEqual(logout_request.issuer, self.provider.issuer)
self.assertEqual(logout_request.session_index, "test-session-123")
def test_post_view_handles_logout_response_with_plan_context(self):
@@ -396,7 +396,7 @@ class TestSPInitiatedSLOViews(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp2.example.com",
issuer="https://idp2.example.com",
invalidation_flow=None, # No invalidation flow
)
@@ -524,7 +524,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signing_kp=self.cert,
@@ -714,7 +714,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="", # No SLS URL
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
)
app_no_sls = Application.objects.create(

View File

@@ -11,12 +11,6 @@ from authentik.providers.saml.views.sp_slo import (
)
urlpatterns = [
# Base path for Issuer/Entity ID
path(
"<slug:application_slug>/",
sso.SAMLSSOBindingRedirectView.as_view(),
name="base",
),
# SSO Bindings
path(
"<slug:application_slug>/sso/binding/redirect/",

View File

@@ -81,7 +81,6 @@ class SAMLFlowFinalView(ChallengeStageView):
"session": auth_session,
"name_id": processor.name_id,
"name_id_format": processor.name_id_format,
"issuer": processor.issuer,
"expires": processor.session_not_on_or_after_datetime,
"expiring": True,
},

View File

@@ -107,25 +107,12 @@ class SPInitiatedSLOView(PolicyAccessView):
# Store relay state for the logout response
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = relay_state
# Look up the session issuer to use in the logout response
auth_session = AuthenticatedSession.from_request(request, request.user)
session_issuer = None
if auth_session:
saml_session = SAMLSession.objects.filter(
session=auth_session,
user=request.user,
provider=self.provider,
).first()
if saml_session:
session_issuer = saml_session.issuer
if self.provider.logout_method == SAMLLogoutMethods.FRONTCHANNEL_NATIVE:
# Native mode - user will be redirected/posted away from authentik
processor = LogoutResponseProcessor(
self.provider,
logout_request,
destination=self.provider.sls_url,
issuer=session_issuer,
)
if self.provider.sls_binding == SAMLBindings.POST:
@@ -165,7 +152,6 @@ class SPInitiatedSLOView(PolicyAccessView):
sls_url=self.provider.sls_url,
logout_request_id=logout_request.id if logout_request else None,
relay_state=relay_state,
issuer=session_issuer,
)
LOGGER.debug(
@@ -182,7 +168,6 @@ class SPInitiatedSLOView(PolicyAccessView):
self.provider,
logout_request,
destination=self.provider.sls_url,
issuer=session_issuer,
)
logout_response = processor.build_response()

View File

@@ -97,9 +97,6 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
if cached_config is not None:
return cached_config
if self.provider.compatibility_mode == SCIMCompatibilityMode.VCENTER:
return default_config
# Attempt to fetch from remote
path = "/ServiceProviderConfig"
if self.provider.compatibility_mode == SCIMCompatibilityMode.SALESFORCE:

View File

@@ -94,7 +94,6 @@ class Migration(migrations.Migration):
("slack", "Slack"),
("sfdc", "Salesforce"),
("webex", "Webex"),
("vcenter", "vCenter"),
],
default="default",
help_text="Alter authentik behavior for vendor-specific SCIM implementations.",

View File

@@ -83,7 +83,6 @@ class SCIMCompatibilityMode(models.TextChoices):
SLACK = "slack", _("Slack")
SALESFORCE = "sfdc", _("Salesforce")
WEBEX = "webex", _("Webex")
VCENTER = "vcenter", _("vCenter")
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):

View File

@@ -94,7 +94,7 @@ class OAuthCallback(OAuthClientMixin, View):
def get_user_id(self, info: dict[str, Any]) -> str | None:
"""Return unique identifier from the profile info."""
if "id" in info:
return str(info["id"])
return info["id"]
return None
def handle_login_failure(self, reason: str) -> HttpResponse:

View File

@@ -353,7 +353,7 @@ class IdentificationStageView(ChallengeStageView):
PLAN_CONTEXT_APPLICATION, Application()
)
challenge.initial_data["application_pre"] = app.name
if not app.meta_hide and (launch_url := app.get_launch_url()):
if launch_url := app.get_launch_url():
challenge.initial_data["application_pre_launch"] = launch_url
if (
PLAN_CONTEXT_DEVICE in self.executor.plan.context

View File

@@ -5215,11 +5215,6 @@
"type": "string",
"title": "Group"
},
"meta_hide": {
"type": "boolean",
"title": "Meta hide",
"description": "Hide this application from the user's My applications page."
},
"icon": {
"type": "string",
"minLength": 1,
@@ -5537,14 +5532,6 @@
"minLength": 1,
"title": "Password"
},
"password_hash": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Password hash"
},
"permissions": {
"type": "array",
"items": {
@@ -9975,23 +9962,6 @@
"title": "Client Type",
"description": "Confidential clients are capable of maintaining the confidentiality of their credentials. Public clients are incapable"
},
"grant_types": {
"type": "array",
"items": {
"type": "string",
"enum": [
"authorization_code",
"implicit",
"hybrid",
"refresh_token",
"client_credentials",
"password",
"urn:ietf:params:oauth:grant-type:device_code"
],
"title": "Grant types"
},
"title": "Grant types"
},
"client_id": {
"type": "string",
"maxLength": 255,
@@ -10825,10 +10795,11 @@
"title": "Audience",
"description": "Value of the audience restriction field of the assertion. When left empty, no audience restriction will be added."
},
"issuer_override": {
"issuer": {
"type": "string",
"title": "Issuer override",
"description": "Also known as EntityID. Providing a value overrides the default issuer generated by authentik."
"minLength": 1,
"title": "Issuer",
"description": "Also known as EntityID"
},
"assertion_valid_not_before": {
"type": "string",
@@ -11111,8 +11082,7 @@
"aws",
"slack",
"sfdc",
"webex",
"vcenter"
"webex"
],
"title": "SCIM Compatibility Mode",
"description": "Alter authentik behavior for vendor-specific SCIM implementations."

View File

@@ -11,7 +11,6 @@ context:
group_name: authentik Admins
email: !Env [AUTHENTIK_BOOTSTRAP_EMAIL, "root@example.com"]
password: !Env [AUTHENTIK_BOOTSTRAP_PASSWORD, null]
password_hash: !Env [AUTHENTIK_BOOTSTRAP_PASSWORD_HASH, null]
token: !Env [AUTHENTIK_BOOTSTRAP_TOKEN, null]
entries:
- model: authentik_core.group
@@ -32,7 +31,6 @@ entries:
groups:
- !KeyOf admin-group
password: !Context password
password_hash: !Context password_hash
- model: authentik_core.token
state: created
conditions:

View File

@@ -75,10 +75,6 @@ entries:
url: https://localhost:8443/test/a/authentik/callback
- matching_mode: strict
url: https://host.docker.internal:8443/test/a/authentik/callback
grant_types:
- authorization_code
- implicit
- refresh_token
property_mappings:
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-openid]]
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-email]]
@@ -110,10 +106,6 @@ entries:
url: https://localhost:8443/test/a/authentik/callback
- matching_mode: strict
url: https://host.docker.internal:8443/test/a/authentik/callback
grant_types:
- authorization_code
- implicit
- refresh_token
property_mappings:
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-openid]]
- !Find [authentik_providers_oauth2.scopemapping, [managed, goauthentik.io/providers/oauth2/scope-email]]

View File

@@ -1,6 +0,0 @@
// Cargo build script: only needed when the `core` feature (the embedded
// Python interpreter via pyo3) is enabled.
fn main() {
    #[cfg(feature = "core")]
    {
        // Emit linker args so the produced binary carries an rpath to
        // libpython, allowing it to be found at runtime.
        pyo3_build_config::add_libpython_rpath_link_args();
    }
}

2
go.mod
View File

@@ -7,7 +7,7 @@ require (
beryju.io/radius-eap v0.1.0
github.com/avast/retry-go/v4 v4.7.0
github.com/coreos/go-oidc/v3 v3.18.0
github.com/getsentry/sentry-go v0.46.0
github.com/getsentry/sentry-go v0.45.1
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.13
github.com/go-openapi/runtime v0.29.4

4
go.sum
View File

@@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/getsentry/sentry-go v0.46.0 h1:mbdDaarbUdOt9X+dx6kDdntkShLEX3/+KyOsVDTPDj0=
github.com/getsentry/sentry-go v0.46.0/go.mod h1:evVbw2qotNUdYG8KxXbAdjOQWWvWIwKxpjdZZIvcIPw=
github.com/getsentry/sentry-go v0.45.1 h1:9rfzJtGiJG+MGIaWZXidDGHcH5GU1Z5y0WVJGf9nysw=
github.com/getsentry/sentry-go v0.45.1/go.mod h1:XDotiNZbgf5U8bPDUAfvcFmOnMQQceESxyKaObSssW0=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=

View File

@@ -29,7 +29,7 @@ RUN npm run build && \
npm run build:sfe
# Stage: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS go-builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS go-builder
ARG TARGETOS
ARG TARGETARCH
@@ -147,11 +147,8 @@ RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
--mount=type=bind,target=Cargo.toml,src=Cargo.toml \
--mount=type=bind,target=Cargo.lock,src=Cargo.lock \
--mount=type=bind,target=.cargo/,src=.cargo/ \
--mount=type=bind,target=build.rs,src=build.rs \
--mount=type=bind,target=src/,src=src/ \
--mount=type=bind,target=packages/ak-axum,src=packages/ak-axum \
--mount=type=bind,target=packages/ak-common,src=packages/ak-common \
--mount=type=bind,target=packages/client-rust,src=packages/client-rust \
--mount=type=bind,target=packages/,src=packages/ \
--mount=type=bind,target=authentik/lib/default.yml,src=authentik/lib/default.yml \
# Required otherwise workspace discovery fails
--mount=type=bind,target=website/scripts/docsmg/,src=website/scripts/docsmg/ \
@@ -194,10 +191,7 @@ COPY --from=rust-toolchain /root/.cargo /root/.cargo
ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
--mount=type=bind,target=uv.lock,src=uv.lock \
--mount=type=bind,target=packages/ak-guardian,src=packages/ak-guardian \
--mount=type=bind,target=packages/django-channels-postgres,src=packages/django-channels-postgres \
--mount=type=bind,target=packages/django-dramatiq-postgres,src=packages/django-dramatiq-postgres \
--mount=type=bind,target=packages/django-postgres-cache,src=packages/django-postgres-cache \
--mount=type=bind,target=packages,src=packages \
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
--mount=type=cache,id=uv-python-deps-$TARGETARCH$TARGETVARIANT,target=/root/.cache/uv \
uv sync --frozen --no-install-project --no-dev

View File

@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
ARG TARGETOS
ARG TARGETARCH

View File

@@ -21,7 +21,7 @@ COPY web .
RUN npm run build-proxy
# Stage 2: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
ARG TARGETOS
ARG TARGETARCH

View File

@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
ARG TARGETOS
ARG TARGETARCH

View File

@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:982ae929f9a74083a242c6e25d19d7d9ed78c6e97fab639a119e90707ba819e2 AS builder
ARG TARGETOS
ARG TARGETARCH

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
import faulthandler
import os
import random
import signal
@@ -77,12 +76,6 @@ def main(worker_id: int, socket_path: str):
signal.signal(signal.SIGINT, immediate_shutdown)
signal.signal(signal.SIGQUIT, immediate_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)
# SIGUSR1 dumps every thread's traceback to stderr. Without this, the default
# action is "terminate", which kills the worker (and trips the Rust supervisor).
# Side-benefit: signal delivery wakes the eval loop, so `pdb -p` can attach to
# an otherwise-idle worker parked in a C-level syscall.
faulthandler.enable()
faulthandler.register(signal.SIGUSR1)
random.seed()
@@ -104,11 +97,7 @@ def main(worker_id: int, socket_path: str):
# Notify rust process that we are ready
os.kill(os.getppid(), signal.SIGUSR2)
# Poll instead of waiting indefinitely so the main thread's eval loop ticks
# periodically — PEP 768's debugger pending hook is serviced on the main
# thread, and a permanent Event.wait() never returns to bytecode execution.
while not shutdown.wait(timeout=1.0):
pass
shutdown.wait()
logger.info("Shutting down worker...")

Binary file not shown.

View File

@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2026-04-29 00:28+0000\n"
"POT-Creation-Date: 2026-04-23 00:25+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@@ -392,10 +392,6 @@ msgstr ""
msgid "Open launch URL in a new browser tab or window."
msgstr ""
#: authentik/core/models.py
msgid "Hide this application from the user's My applications page."
msgstr ""
#: authentik/core/models.py
msgid "Application"
msgstr ""
@@ -2495,9 +2491,7 @@ msgid ""
msgstr ""
#: authentik/providers/saml/models.py
msgid ""
"Also known as EntityID. Providing a value overrides the default issuer "
"generated by authentik."
msgid "Also known as EntityID"
msgstr ""
#: authentik/providers/saml/models.py
@@ -2691,10 +2685,6 @@ msgstr ""
msgid "SAML NameID format"
msgstr ""
#: authentik/providers/saml/models.py
msgid "SAML Issuer used for this session"
msgstr ""
#: authentik/providers/saml/models.py
msgid "SAML Session"
msgstr ""
@@ -2727,10 +2717,6 @@ msgstr ""
msgid "Webex"
msgstr ""
#: authentik/providers/scim/models.py
msgid "vCenter"
msgstr ""
#: authentik/providers/scim/models.py
msgid "Group filters used to define sync-scope for groups."
msgstr ""

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,6 @@ durstr.workspace = true
eyre.workspace = true
forwarded-header-value.workspace = true
futures.workspace = true
pin-project-lite.workspace = true
tokio-rustls.workspace = true
tokio.workspace = true
tower-http.workspace = true

View File

@@ -1,737 +0,0 @@
//! axum-server acceptor that catches panics and shuts down the application.
use std::{
any::Any,
io::{self, IoSlice},
panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
task::{Context, Poll},
};
use ak_common::Arbiter;
use axum_server::accept::Accept;
use futures::{FutureExt as _, future::BoxFuture};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tower::Service;
use tracing::{error, instrument};
fn extract_panic_msg<'a>(panic: &'a Box<dyn Any + Send + 'static>) -> &'a str {
panic
.downcast_ref::<&str>()
.copied()
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
.unwrap_or("unknown panic message")
}
/// Acceptor catching panics from the underlying acceptor.
///
/// Also wraps the stream and service to catch panics.
#[derive(Clone)]
pub(crate) struct CatchPanicAcceptor<A> {
    // The wrapped acceptor whose `accept` calls are guarded against panics.
    inner: A,
    // Handle used to request an application-wide fast shutdown when a panic
    // is caught anywhere in the accept/stream/service pipeline.
    arbiter: Arbiter,
}
impl<A> CatchPanicAcceptor<A> {
pub(crate) fn new(inner: A, arbiter: Arbiter) -> Self {
Self { inner, arbiter }
}
}
// Delegates the TLS/plain handshake to the inner acceptor inside a
// `catch_unwind` guard. On success, the accepted stream and service are
// re-wrapped so later I/O and request handling are guarded too; on panic, a
// fast shutdown is requested and the panic is re-raised.
impl<A, I, S> Accept<I, S> for CatchPanicAcceptor<A>
where
    A: Accept<I, S> + Clone + Send + 'static,
    A::Stream: AsyncRead + AsyncWrite + Send,
    A::Service: Send,
    A::Future: Send,
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: Send + 'static,
{
    // Boxed so the opaque `catch_unwind` future can be named as an
    // associated type.
    type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
    type Service = CatchPanicService<A::Service>;
    type Stream = CatchPanicStream<A::Stream>;
    #[instrument(skip_all)]
    fn accept(&self, stream: I, service: S) -> Self::Future {
        // Clone out of `&self` so the returned future is `'static`.
        let acceptor = self.inner.clone();
        let arbiter = self.arbiter.clone();
        Box::pin(async move {
            // AssertUnwindSafe: the future is consumed either way — on panic
            // we shut down and re-raise, so no broken state is observed.
            match AssertUnwindSafe(acceptor.accept(stream, service))
                .catch_unwind()
                .await
            {
                Ok(result) => {
                    // Inner acceptor may still fail with a plain I/O error.
                    let (stream, service) = result?;
                    Ok((
                        CatchPanicStream::new(stream, arbiter.clone()),
                        CatchPanicService::new(service, arbiter),
                    ))
                }
                Err(panic) => {
                    error!(
                        panic = extract_panic_msg(&panic),
                        "acceptor panicked, shutting down immediately"
                    );
                    // Async context: we can await the shutdown directly
                    // before propagating the panic.
                    arbiter.do_fast_shutdown().await;
                    resume_unwind(panic);
                }
            }
        })
    }
}
pin_project! {
    /// A stream wrapper that catches panics from the underlying stream.
    pub(crate) struct CatchPanicStream<S> {
        // Structurally pinned: the inner stream's poll_* methods require
        // `Pin<&mut S>`, obtained via the generated `project()`.
        #[pin]
        inner: S,
        // Used to request a fast shutdown when a poll_* call panics.
        arbiter: Arbiter,
    }
}
impl<S> CatchPanicStream<S> {
pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self {
Self { inner, arbiter }
}
}
// Forwards `poll_read` to the inner stream inside a `catch_unwind` guard.
impl<S: AsyncRead> AsyncRead for CatchPanicStream<S> {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();
        match catch_unwind(AssertUnwindSafe(|| this.inner.poll_read(cx, buf))) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream poll_read panicked, shutting down immediately"
                );
                // poll_* is sync, so the async shutdown is spawned instead of
                // awaited; the panic is re-raised immediately after.
                let arbiter = this.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic);
            }
        }
    }
}
// Forwards every AsyncWrite method to the inner stream inside a
// `catch_unwind` guard. Each panic arm follows the same recipe: log the
// panic message, spawn an async fast shutdown (poll_* is sync, so it cannot
// be awaited here), then re-raise the panic with `resume_unwind`.
impl<S: AsyncWrite> AsyncWrite for CatchPanicStream<S> {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        match catch_unwind(AssertUnwindSafe(|| this.inner.poll_write(cx, buf))) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream poll_write panicked, shutting down immediately"
                );
                let arbiter = this.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic);
            }
        }
    }
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        match catch_unwind(AssertUnwindSafe(|| {
            this.inner.poll_write_vectored(cx, bufs)
        })) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream poll_write_vectored panicked, shutting down immediately"
                );
                let arbiter = this.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic)
            }
        }
    }
    // Non-poll accessor: takes `&self`, so no projection is needed.
    fn is_write_vectored(&self) -> bool {
        match catch_unwind(AssertUnwindSafe(|| self.inner.is_write_vectored())) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream is_write_vectored panicked, shutting down immediately"
                );
                let arbiter = self.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic);
            }
        }
    }
    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        match catch_unwind(AssertUnwindSafe(|| this.inner.poll_flush(cx))) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream poll_flush panicked, shutting down immediately"
                );
                let arbiter = this.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic);
            }
        }
    }
    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        match catch_unwind(AssertUnwindSafe(|| this.inner.poll_shutdown(cx))) {
            Ok(result) => result,
            Err(panic) => {
                error!(
                    panic = extract_panic_msg(&panic),
                    "stream poll_shutdown panicked, shutting down immediately"
                );
                let arbiter = this.arbiter.clone();
                tokio::spawn(async move { arbiter.do_fast_shutdown().await });
                resume_unwind(panic);
            }
        }
    }
}
/// A panic wrapper that catches panics from the underlying service.
///
/// A panic raised while the wrapped service is polled or called is logged,
/// triggers an immediate fast shutdown through the arbiter, and is then
/// propagated to the caller (see the `Service` impl below).
#[derive(Clone)]
pub(crate) struct CatchPanicService<S> {
    // The wrapped service whose panics are intercepted.
    inner: S,
    // Handle used to request an immediate application shutdown on panic.
    arbiter: Arbiter,
}
impl<S> CatchPanicService<S> {
pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self {
Self { inner, arbiter }
}
}
impl<S, R> Service<R> for CatchPanicService<S>
where
    S: Service<R>,
{
    type Error = S::Error;
    type Future = CatchPanicFuture<S::Future>;
    type Response = S::Response;

    /// Polls the wrapped service for readiness.
    ///
    /// A panic is logged, triggers an immediate fast shutdown of the arbiter,
    /// and is then re-raised via `resume_unwind`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let service = &mut self.inner;
        let outcome = catch_unwind(AssertUnwindSafe(|| service.poll_ready(cx)));
        match outcome {
            Err(payload) => {
                error!(
                    panic = extract_panic_msg(&payload),
                    "service poll_ready panicked, shutting down immediately"
                );
                let shutdown_arbiter = self.arbiter.clone();
                tokio::spawn(async move { shutdown_arbiter.do_fast_shutdown().await });
                resume_unwind(payload)
            }
            Ok(readiness) => readiness,
        }
    }

    /// Calls the wrapped service and wraps the returned future in
    /// [`CatchPanicFuture`] so panics during polling are also intercepted.
    ///
    /// A panic in the call itself (before a future is produced) is logged,
    /// triggers a fast shutdown, and is re-raised.
    fn call(&mut self, req: R) -> Self::Future {
        let service = &mut self.inner;
        match catch_unwind(AssertUnwindSafe(|| service.call(req))) {
            Ok(inner_future) => CatchPanicFuture::new(inner_future, self.arbiter.clone()),
            Err(payload) => {
                error!(
                    panic = extract_panic_msg(&payload),
                    "service call panicked, shutting down immediately"
                );
                let shutdown_arbiter = self.arbiter.clone();
                tokio::spawn(async move { shutdown_arbiter.do_fast_shutdown().await });
                resume_unwind(payload)
            }
        }
    }
}
pin_project! {
    /// A Future wrapper that catches panics from the inner future.
    ///
    /// A panic during `poll` is logged, triggers an immediate fast shutdown
    /// through the arbiter, and is then propagated (see the `Future` impl).
    pub(crate) struct CatchPanicFuture<F> {
        // The wrapped future; structurally pinned so `poll` can project to it.
        #[pin]
        inner: F,
        // Handle used to request an immediate application shutdown on panic.
        arbiter: Arbiter,
    }
}
impl<F> CatchPanicFuture<F> {
fn new(inner: F, arbiter: Arbiter) -> Self {
Self { inner, arbiter }
}
}
impl<F: Future> Future for CatchPanicFuture<F> {
    type Output = F::Output;

    /// Polls the wrapped future.
    ///
    /// A panic is logged, triggers an immediate fast shutdown of the arbiter,
    /// and is then re-raised via `resume_unwind`.
    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let outcome = catch_unwind(AssertUnwindSafe(|| projected.inner.poll(cx)));
        match outcome {
            Err(payload) => {
                error!(
                    panic = extract_panic_msg(&payload),
                    "service future panicked, shutting down immediately"
                );
                let shutdown_arbiter = projected.arbiter.clone();
                tokio::spawn(async move { shutdown_arbiter.do_fast_shutdown().await });
                resume_unwind(payload)
            }
            Ok(ready) => ready,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        io,
        panic::{AssertUnwindSafe, panic_any},
        task::{Context, Poll},
    };

    use ak_common::{Arbiter, Tasks};
    use axum_server::accept::Accept;
    use futures::{
        FutureExt as _,
        future::{BoxFuture, poll_fn},
    };
    use tokio::{
        io::{AsyncReadExt as _, AsyncWriteExt as _, DuplexStream, ReadBuf, duplex},
        time::{Duration, timeout},
    };
    use tower::Service;

    use super::{CatchPanicAcceptor, CatchPanicService, CatchPanicStream};

    /// Returns one half of an in-memory duplex pipe, used as a benign stream.
    fn duplex_stream() -> DuplexStream {
        let (stream, _peer) = duplex(1024);
        stream
    }

    /// Returns `true` if the arbiter's fast-shutdown has already been triggered.
    async fn fast_shutdown_triggered(arbiter: &Arbiter) -> bool {
        timeout(Duration::from_millis(50), arbiter.fast_shutdown())
            .await
            .is_ok()
    }

    // --- Acceptor fixtures --------------------------------------------------

    /// Accepts every connection unchanged.
    #[derive(Clone)]
    struct OkAcceptor;

    impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for OkAcceptor {
        type Future = BoxFuture<'static, io::Result<(I, S)>>;
        type Service = S;
        type Stream = I;

        fn accept(&self, stream: I, service: S) -> Self::Future {
            Box::pin(async move { Ok((stream, service)) })
        }
    }

    /// Fails every accept with a plain I/O error (no panic).
    #[derive(Clone)]
    struct ErrorAcceptor;

    impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for ErrorAcceptor {
        type Future = BoxFuture<'static, io::Result<(I, S)>>;
        type Service = S;
        type Stream = I;

        fn accept(&self, _stream: I, _service: S) -> Self::Future {
            Box::pin(async move { Err(io::Error::other("inner error")) })
        }
    }

    /// Panics with a `&'static str` payload.
    #[derive(Clone)]
    struct PanicStrAcceptor;

    impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicStrAcceptor {
        type Future = BoxFuture<'static, io::Result<(I, S)>>;
        type Service = S;
        type Stream = I;

        fn accept(&self, _stream: I, _service: S) -> Self::Future {
            Box::pin(async move { panic!("str panic message") })
        }
    }

    /// Panics with a `String` payload.
    #[derive(Clone)]
    struct PanicStringAcceptor;

    impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicStringAcceptor {
        type Future = BoxFuture<'static, io::Result<(I, S)>>;
        type Service = S;
        type Stream = I;

        fn accept(&self, _stream: I, _service: S) -> Self::Future {
            Box::pin(async move {
                let msg = "string panic message".to_owned();
                panic_any(msg)
            })
        }
    }

    /// Panics with a payload that is neither `&str` nor `String`.
    #[derive(Clone)]
    struct PanicUnknownAcceptor;

    impl<I: Send + 'static, S: Send + 'static> Accept<I, S> for PanicUnknownAcceptor {
        type Future = BoxFuture<'static, io::Result<(I, S)>>;
        type Service = S;
        type Stream = I;

        fn accept(&self, _stream: I, _service: S) -> Self::Future {
            Box::pin(async move { panic_any(42u32) })
        }
    }

    // --- Stream fixture -----------------------------------------------------

    /// A stream whose every I/O method panics.
    struct PanicStream;

    impl tokio::io::AsyncRead for PanicStream {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("poll_read panic")
        }
    }

    impl tokio::io::AsyncWrite for PanicStream {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("poll_write panic")
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("poll_flush panic")
        }

        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("poll_shutdown panic")
        }
    }

    // --- Service fixtures ---------------------------------------------------

    /// A service that is always ready and always succeeds.
    #[derive(Clone)]
    struct OkService;

    impl Service<()> for OkService {
        type Error = Infallible;
        type Future = futures::future::Ready<Result<(), Infallible>>;
        type Response = ();

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            futures::future::ready(Ok(()))
        }
    }

    /// Panics when polled for readiness.
    struct PanicPollReadyService;

    impl Service<()> for PanicPollReadyService {
        type Error = Infallible;
        type Future = futures::future::Ready<Result<(), Infallible>>;
        type Response = ();

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            panic!("poll_ready panic")
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            unreachable!()
        }
    }

    /// Panics inside the body of `call`, before producing a future.
    struct PanicCallBodyService;

    impl Service<()> for PanicCallBodyService {
        type Error = Infallible;
        type Future = futures::future::Ready<Result<(), Infallible>>;
        type Response = ();

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            panic!("call body panic")
        }
    }

    /// A future that panics when polled.
    struct PanicFuture;

    impl Future for PanicFuture {
        type Output = Result<(), Infallible>;

        fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            panic!("future panic")
        }
    }

    /// Returns a future from `call` that panics when polled.
    struct PanicCallFutureService;

    impl Service<()> for PanicCallFutureService {
        type Error = Infallible;
        type Future = PanicFuture;
        type Response = ();

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            PanicFuture
        }
    }

    // --- Acceptor tests -----------------------------------------------------

    #[tokio::test]
    async fn acceptor_passes_through_success() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let acceptor = CatchPanicAcceptor::new(OkAcceptor, arbiter.clone());

        let result = acceptor.accept(duplex_stream(), OkService).await;
        assert!(result.is_ok());
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn acceptor_passes_through_error() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let acceptor = CatchPanicAcceptor::new(ErrorAcceptor, arbiter.clone());

        let result = acceptor.accept(duplex_stream(), OkService).await;
        assert!(result.is_err());
        assert_eq!(result.err().unwrap().to_string(), "inner error");
        // A plain error must not be treated as a panic.
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn acceptor_catches_str_panic_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let acceptor = CatchPanicAcceptor::new(PanicStrAcceptor, arbiter.clone());

        let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn acceptor_catches_string_panic_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let acceptor = CatchPanicAcceptor::new(PanicStringAcceptor, arbiter.clone());

        let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn acceptor_catches_unknown_panic_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let acceptor = CatchPanicAcceptor::new(PanicUnknownAcceptor, arbiter.clone());

        let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    // --- Stream tests -------------------------------------------------------

    #[tokio::test]
    async fn stream_poll_read_passes_through() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let (mut a, mut b) = duplex(1024);
        b.write_all(b"hello").await.unwrap();
        let mut stream = CatchPanicStream::new(&mut a, arbiter.clone());

        let mut buf = [0u8; 5];
        let result = stream.read(&mut buf).await;
        assert!(result.is_ok());
        assert_eq!(&buf, b"hello");
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    // NOTE: the wrapper re-raises the panic after triggering shutdown, so the
    // tests below assert a propagated panic, not an error return.
    #[tokio::test]
    async fn stream_poll_read_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());

        let result = AssertUnwindSafe(stream.read(&mut [0u8; 10]))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn stream_poll_write_passes_through() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let (mut a, _b) = duplex(1024);
        let mut stream = CatchPanicStream::new(&mut a, arbiter.clone());

        let result = stream.write_all(b"hello").await;
        assert!(result.is_ok());
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn stream_poll_write_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());

        let result = AssertUnwindSafe(stream.write(b"hello"))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn stream_poll_flush_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());

        let result = AssertUnwindSafe(stream.flush()).catch_unwind().await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn stream_poll_shutdown_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone());

        let result = AssertUnwindSafe(stream.shutdown()).catch_unwind().await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    // --- Service tests ------------------------------------------------------

    #[tokio::test]
    async fn service_poll_ready_passes_through() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut service = CatchPanicService::new(OkService, arbiter.clone());

        let result = poll_fn(|cx| service.poll_ready(cx)).await;
        assert!(result.is_ok());
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn service_poll_ready_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut service = CatchPanicService::new(PanicPollReadyService, arbiter.clone());

        let result = AssertUnwindSafe(poll_fn(|cx| service.poll_ready(cx)))
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn service_call_passes_through() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut service = CatchPanicService::new(OkService, arbiter.clone());

        let result = service.call(()).await;
        assert!(result.is_ok());
        assert!(!fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn service_call_body_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut service = CatchPanicService::new(PanicCallBodyService, arbiter.clone());

        let result = AssertUnwindSafe(async { service.call(()).await })
            .catch_unwind()
            .await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }

    #[tokio::test]
    async fn service_call_future_panic_re_panics_and_shuts_down() {
        let tasks = Tasks::new().expect("failed to create tasks");
        let arbiter = tasks.arbiter();
        let mut service = CatchPanicService::new(PanicCallFutureService, arbiter.clone());

        let result = AssertUnwindSafe(service.call(())).catch_unwind().await;
        assert!(result.is_err());
        assert!(fast_shutdown_triggered(&arbiter).await);
    }
}

View File

@@ -1,3 +1,2 @@
pub mod catch_panic;
pub mod proxy_protocol;
pub mod tls;

View File

@@ -1,13 +1,7 @@
//! Utilities for working with [`Router`].
use ak_common::config;
use axum::{
Router,
extract::Request,
http::{HeaderName, HeaderValue, StatusCode},
middleware::{Next, from_fn},
response::Response,
};
use axum::{Router, http::StatusCode, middleware::from_fn};
use tower::ServiceBuilder;
use tower_http::timeout::TimeoutLayer;
@@ -19,16 +13,6 @@ use crate::{
tracing::{span_middleware, tracing_middleware},
};
const X_POWERED_BY: HeaderName = HeaderName::from_static("x-powered-by");
async fn powered_by_authentik_middleware(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
response
.headers_mut()
.insert(X_POWERED_BY, HeaderValue::from_static("authentik"));
response
}
/// Wrap a [`Router`] with common middlewares.
///
/// Set `with_tracing` to [`true`] to log requests.
@@ -46,7 +30,6 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router {
timeout,
))
.layer(from_fn(span_middleware))
.layer(from_fn(powered_by_authentik_middleware))
.layer(from_fn(trusted_proxy_middleware))
.layer(from_fn(client_ip_middleware))
.layer(from_fn(scheme_middleware))

View File

@@ -12,9 +12,7 @@ use axum_server::{
use eyre::Result;
use tracing::{info, trace};
use crate::accept::{
catch_panic::CatchPanicAcceptor, proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor,
};
use crate::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor};
async fn run_plain(
arbiter: Arbiter,
@@ -29,10 +27,7 @@ async fn run_plain(
arbiter.add_net_handle(handle.clone()).await;
let res = axum_server::Server::bind(addr)
.acceptor(CatchPanicAcceptor::new(
ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()),
arbiter.clone(),
))
.acceptor(ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()))
.handle(handle)
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
.await;
@@ -85,10 +80,7 @@ pub(crate) async fn run_unix(
}
}
let res = axum_server::Server::bind(addr.clone())
.acceptor(CatchPanicAcceptor::new(
DefaultAcceptor::new(),
arbiter.clone(),
))
.acceptor(DefaultAcceptor::new())
.handle(handle)
.serve(router.into_make_service())
.await;
@@ -141,12 +133,9 @@ async fn run_tls(
arbiter.add_net_handle(handle.clone()).await;
axum_server::Server::bind(addr)
.acceptor(CatchPanicAcceptor::new(
ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new(
RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()),
)),
arbiter.clone(),
))
.acceptor(ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new(
RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()),
)))
.handle(handle)
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
.await?;

View File

@@ -235,7 +235,7 @@ impl Arbiter {
}
/// Shutdown the application immediately.
pub async fn do_fast_shutdown(&self) {
async fn do_fast_shutdown(&self) {
info!("arbiter has been told to shutdown immediately");
self.unix_handles
.lock()
@@ -253,7 +253,7 @@ impl Arbiter {
}
/// Shutdown the application gracefully.
pub async fn do_graceful_shutdown(&self) {
async fn do_graceful_shutdown(&self) {
info!("arbiter has been told to shutdown gracefully");
// Match the value in lifecycle/gunicorn.conf.py for graceful shutdown
let timeout = Some(Duration::from_secs(30 + 5));

View File

@@ -16,10 +16,7 @@ use url::Url;
pub mod schema;
pub use schema::Config;
use crate::{
arbiter::{Arbiter, Event, Tasks},
config::schema::KEYS_TO_PARSE_AS_LIST,
};
use crate::arbiter::{Arbiter, Event, Tasks};
static DEFAULT_CONFIG: &str = include_str!("../../../../authentik/lib/default.yml");
static CONFIG_MANAGER: OnceLock<ConfigManager> = OnceLock::new();
@@ -78,15 +75,11 @@ impl Config {
config_rs::File::from(path.as_path()).format(config_rs::FileFormat::Yaml),
);
}
let mut env_source = config_rs::Environment::with_prefix("AUTHENTIK")
.prefix_separator("_")
.separator("__")
.try_parsing(true)
.list_separator(",");
for key in KEYS_TO_PARSE_AS_LIST {
env_source = env_source.with_list_parse_key(key);
}
builder = builder.add_source(env_source);
builder = builder.add_source(
config_rs::Environment::with_prefix("AUTHENTIK")
.prefix_separator("_")
.separator("__"),
);
if let Some(overrides) = overrides {
builder = builder.add_source(config_rs::File::from_str(
&overrides.to_string(),
@@ -462,92 +455,4 @@ mod tests {
super::set(json!({"secret_key": "my_new_secret_key"})).expect("failed to set config");
assert_eq!(super::get().secret_key, "my_new_secret_key");
}
#[test]
fn env_bool_true() {
#[expect(unsafe_code, reason = "testing")]
// SAFETY: testing
unsafe {
env::set_var("AUTHENTIK_DEBUG", "true");
}
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
assert!(config.debug);
}
#[test]
fn env_bool_false() {
#[expect(unsafe_code, reason = "testing")]
// SAFETY: testing
unsafe {
env::set_var("AUTHENTIK_DEBUG", "false");
}
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
assert!(!config.debug);
}
// See https://github.com/rust-cli/config-rs/issues/443
// #[test]
// fn env_list_empty() {
// #[expect(unsafe_code, reason = "testing")]
// // SAFETY: testing
// unsafe {
// env::set_var("AUTHENTIK_LISTEN__HTTP", "");
// }
//
// let (config, _) = super::Config::load(&[], None).expect("failed to load config");
//
// assert_eq!(config.listen.http, []);
// }
#[test]
fn env_list_one_element() {
#[expect(unsafe_code, reason = "testing")]
// SAFETY: testing
unsafe {
env::set_var("AUTHENTIK_LISTEN__HTTP", "[::1]:9000");
}
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
assert_eq!(
config.listen.http,
["[::1]:9000".parse().expect("infallible")]
);
}
#[test]
fn env_list_many_elements() {
#[expect(unsafe_code, reason = "testing")]
// SAFETY: testing
unsafe {
env::set_var("AUTHENTIK_LISTEN__HTTP", "[::1]:9000,[::1]:9001");
}
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
assert_eq!(
config.listen.http,
[
"[::1]:9000".parse().expect("infallible"),
"[::1]:9001".parse().expect("infallible")
]
);
}
#[test]
fn env_string() {
#[expect(unsafe_code, reason = "testing")]
// SAFETY: testing
unsafe {
env::set_var("AUTHENTIK_SECRET_KEY", "my_secret_key");
}
let (config, _) = super::Config::load(&[], None).expect("failed to load config");
assert_eq!(config.secret_key, "my_secret_key",);
}
}

Some files were not shown because too many files have changed in this diff Show More