Compare commits

..

3 Commits

Author SHA1 Message Date
Teffen Ellis
64e7fa6bc0 root/middleware: drop nonce from style-src so dynamic <style> is allowed
Mermaid (and a handful of other libs in the bundle) inject `<style>`
elements at runtime via createElement+appendChild, which never carries a
nonce. CSP3 §6.6.2.2 specifies that browsers ignore `'unsafe-inline'`
whenever a nonce is also present in the same source list, so we cannot
have both — keep the nonce and break dynamic styling, or drop it and
rely on `'unsafe-inline'`. Script-side CSP keeps its nonce + strict
allowlist; only style-src is relaxed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:02:50 +02:00
Teffen Ellis
4b4968c66b Flesh out CSP requirements. 2026-04-27 23:30:04 +02:00
Teffen Ellis
4e5b938ebe Flesh out CSP. 2026-04-27 22:01:20 +02:00
397 changed files with 5257 additions and 33881 deletions

View File

@@ -64,7 +64,7 @@ runs:
rustflags: ""
- name: Setup rust dependencies
if: ${{ contains(inputs.dependencies, 'rust') }}
uses: taiki-e/install-action@b5fddbb5361bce8a06fb168c9d403a6cc552b084 # v2
uses: taiki-e/install-action@787505cde8a44ea468a00478fe52baf23b15bccd # v2
with:
tool: cargo-deny cargo-machete cargo-llvm-cov nextest
- name: Setup node (web)
@@ -104,7 +104,7 @@ runs:
working-directory: ${{ inputs.working-directory }}
run: |
export PSQL_TAG=${{ inputs.postgresql_version }}
docker compose -f .github/actions/setup/compose.yml up -d --wait
docker compose -f .github/actions/setup/compose.yml up -d
cd web && npm ci
- name: Generate config
if: ${{ contains(inputs.dependencies, 'python') }}

View File

@@ -8,14 +8,8 @@ services:
POSTGRES_USER: authentik
POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"
POSTGRES_DB: authentik
PGDATA: /var/lib/postgresql/data/pgdata
ports:
- 5432:5432
healthcheck:
test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB} -h 127.0.0.1"]
interval: 1s
timeout: 5s
retries: 60
restart: always
s3:
container_name: s3

View File

@@ -90,7 +90,7 @@ jobs:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- uses: int128/docker-manifest-create-action@fa55f72001a6c74b0f4997dca65c70d334905180 # v2
- uses: int128/docker-manifest-create-action@7df7f9e221d927eaadf87db231ddf728047308a4 # v2
id: build
with:
tags: ${{ matrix.tag }}

View File

@@ -282,18 +282,10 @@ jobs:
fail-fast: false
matrix:
job:
- name: oidc_basic
glob: tests/openid_conformance/test_oidc_basic.py
- name: oidc_implicit
glob: tests/openid_conformance/test_oidc_implicit.py
- name: oidc_rp-initiated
glob: tests/openid_conformance/test_oidc_rp_initiated.py
- name: oidc_frontchannel
glob: tests/openid_conformance/test_oidc_frontchannel.py
- name: oidc_backchannel
glob: tests/openid_conformance/test_oidc_backchannel.py
- name: ssf_transmitter
glob: tests/openid_conformance/test_ssf_transmitter.py
- name: basic
glob: tests/openid_conformance/test_basic.py
- name: implicit
glob: tests/openid_conformance/test_implicit.py
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
- name: Setup authentik env

5
.gitignore vendored
View File

@@ -229,11 +229,6 @@ source_docs/
### Golang ###
/vendor/
server
proxy
ldap
rac
radius
### Docker ###
tests/openid_conformance/exports/*.zip

View File

@@ -14,7 +14,6 @@ pyproject.toml @goauthentik/backend
uv.lock @goauthentik/backend
Cargo.toml @goauthentik/backend
Cargo.lock @goauthentik/backend
build.rs @goauthentik/backend
go.mod @goauthentik/backend
go.sum @goauthentik/backend
.cargo/ @goauthentik/backend

173
Cargo.lock generated
View File

@@ -17,6 +17,18 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "ahash"
version = "0.8.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.4"
@@ -176,33 +188,20 @@ dependencies = [
"arc-swap",
"argh",
"authentik-axum",
"authentik-client",
"authentik-common",
"axum",
"color-eyre",
"eyre",
"futures",
"hyper-unix-socket",
"hyper-util",
"metrics",
"metrics-exporter-prometheus",
"nix 0.31.2",
"pyo3",
"pyo3-build-config",
"rand 0.10.1",
"serde",
"serde_json",
"serde_repr",
"sqlx",
"time",
"tokio",
"tokio-retry2",
"tokio-tungstenite",
"tower",
"tracing",
"url",
"uuid",
"which",
]
[[package]]
@@ -553,17 +552,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chacha20"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
dependencies = [
"cfg-if",
"cpufeatures 0.3.0",
"rand_core 0.10.1",
]
[[package]]
name = "chrono"
version = "0.4.44"
@@ -801,15 +789,6 @@ dependencies = [
"libc",
]
[[package]]
name = "cpufeatures"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
dependencies = [
"libc",
]
[[package]]
name = "crc"
version = "3.4.0"
@@ -1034,17 +1013,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "evmap"
version = "11.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b8874945f036109c72242964c1174cf99434e30cfa45bf45fedc983f50046f8"
dependencies = [
"hashbag",
"left-right",
"smallvec",
]
[[package]]
name = "eyre"
version = "0.6.12"
@@ -1261,21 +1229,6 @@ dependencies = [
"slab",
]
[[package]]
name = "generator"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9"
dependencies = [
"cc",
"cfg-if",
"libc",
"log",
"rustversion",
"windows-link",
"windows-result",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -1322,7 +1275,6 @@ dependencies = [
"cfg-if",
"libc",
"r-efi 6.0.0",
"rand_core 0.10.1",
"wasip2",
"wasip3",
]
@@ -1358,12 +1310,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbag"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7040a10f52cba493ddb09926e15d10a9d8a28043708a405931fe4c6f19fac064"
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -1921,17 +1867,6 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "left-right"
version = "0.11.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0c21e4c8ff95f487fb34e6f9182875f42c84cef966d29216bf115d9bba835a"
dependencies = [
"crossbeam-utils",
"loom",
"slab",
]
[[package]]
name = "libc"
version = "0.2.183"
@@ -2003,19 +1938,6 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "loom"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
dependencies = [
"cfg-if",
"generator",
"scoped-tls",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "lru-slab"
version = "0.1.2"
@@ -2055,22 +1977,21 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
[[package]]
name = "metrics"
version = "0.24.5"
version = "0.24.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff56c2e7dce6bd462e3b8919986a617027481b1dcc703175b58cf9dd98a2f071"
checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8"
dependencies = [
"ahash",
"portable-atomic",
"rapidhash",
]
[[package]]
name = "metrics-exporter-prometheus"
version = "0.18.3"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1db0d8f1fc9e62caebd0319e11eaec5822b0186c171568f0480b46a0137f9108"
checksum = "3589659543c04c7dc5526ec858591015b87cd8746583b51b48ef4353f99dbcda"
dependencies = [
"base64 0.22.1",
"evmap",
"indexmap",
"metrics",
"metrics-util",
@@ -2844,17 +2765,6 @@ dependencies = [
"rand_core 0.9.5",
]
[[package]]
name = "rand"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207"
dependencies = [
"chacha20",
"getrandom 0.4.2",
"rand_core 0.10.1",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@@ -2893,12 +2803,6 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "rand_core"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69"
[[package]]
name = "rand_xoshiro"
version = "0.7.0"
@@ -2908,15 +2812,6 @@ dependencies = [
"rand_core 0.9.5",
]
[[package]]
name = "rapidhash"
version = "4.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e48930979c155e2f33aa36ab3119b5ee81332beb6482199a8ecd6029b80b59"
dependencies = [
"rustversion",
]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
@@ -2975,9 +2870,9 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]]
name = "reqwest"
version = "0.13.3"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0"
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
dependencies = [
"base64 0.22.1",
"bytes",
@@ -3104,9 +2999,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.40"
version = "0.23.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
dependencies = [
"aws-lc-rs",
"log",
@@ -3209,12 +3104,6 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -3468,7 +3357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"cpufeatures",
"digest",
]
@@ -3479,7 +3368,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"cpufeatures",
"digest",
]
@@ -4049,12 +3938,8 @@ checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c"
dependencies = [
"futures-util",
"log",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tungstenite",
"webpki-roots 0.26.11",
]
[[package]]
@@ -4268,11 +4153,8 @@ dependencies = [
"httparse",
"log",
"rand 0.9.2",
"rustls",
"rustls-pki-types",
"sha1",
"thiserror 2.0.18",
"url",
]
[[package]]
@@ -4632,15 +4514,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "8.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81995fafaaaf6ae47a7d0cc83c67caf92aeb7e5331650ae6ff856f7c0c60c459"
dependencies = [
"libc",
]
[[package]]
name = "whoami"
version = "1.6.1"

View File

@@ -43,16 +43,14 @@ hyper-unix-socket = "= 0.6.1"
hyper-util = "= 0.1.20"
ipnet = { version = "= 2.12.0", features = ["serde"] }
json-subscriber = "= 0.2.8"
metrics = "= 0.24.5"
metrics-exporter-prometheus = { version = "= 0.18.3", default-features = false }
metrics = "= 0.24.3"
metrics-exporter-prometheus = { version = "= 0.18.1", default-features = false }
nix = { version = "= 0.31.2", features = ["hostname", "signal"] }
notify = "= 8.2.0"
pin-project-lite = "= 0.2.17"
pyo3 = "= 0.28.3"
pyo3-build-config = "= 0.28.3"
rand = "= 0.10.1"
regex = "= 1.12.3"
reqwest = { version = "= 0.13.3", features = [
reqwest = { version = "= 0.13.2", features = [
"form",
"json",
"multipart",
@@ -67,7 +65,7 @@ reqwest-middleware = { version = "= 0.5.1", features = [
"query",
"rustls",
] }
rustls = { version = "= 0.23.40", features = ["fips"] }
rustls = { version = "= 0.23.39", features = ["fips"] }
sentry = { version = "= 0.47.0", default-features = false, features = [
"backtrace",
"contexts",
@@ -101,10 +99,6 @@ time = { version = "= 0.3.47", features = ["macros"] }
tokio = { version = "= 1.52.1", features = ["full", "tracing"] }
tokio-retry2 = "= 0.9.1"
tokio-rustls = "= 0.26.4"
tokio-tungstenite = { version = "= 0.29.0", features = [
"rustls-tls-webpki-roots",
"url",
] }
tokio-util = { version = "= 0.7.18", features = ["full"] }
tower = "= 0.5.3"
tower-http = { version = "= 0.6.8", features = ["timeout"] }
@@ -118,7 +112,6 @@ tracing-subscriber = { version = "= 0.3.23", features = [
] }
url = "= 2.5.8"
uuid = { version = "= 1.23.1", features = ["serde", "v4"] }
which = "= 8.0.2"
ak-axum = { package = "authentik-axum", version = "2026.5.0-rc1", path = "./packages/ak-axum" }
ak-client = { package = "authentik-client", version = "2026.5.0-rc1", path = "./packages/client-rust" }
@@ -265,41 +258,26 @@ publish.workspace = true
[features]
default = ["core", "proxy"]
core = ["ak-common/core", "dep:pyo3", "dep:sqlx"]
proxy = ["ak-common/proxy", "dep:ak-client"]
[build-dependencies]
pyo3-build-config.workspace = true
proxy = ["ak-common/proxy"]
[dependencies]
ak-axum.workspace = true
ak-client = { workspace = true, optional = true }
ak-common.workspace = true
arc-swap.workspace = true
argh.workspace = true
axum.workspace = true
color-eyre.workspace = true
eyre.workspace = true
futures.workspace = true
hyper-unix-socket.workspace = true
hyper-util.workspace = true
metrics-exporter-prometheus.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
nix.workspace = true
pyo3 = { workspace = true, optional = true }
rand.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_repr.workspace = true
sqlx = { workspace = true, optional = true }
time.workspace = true
tokio-retry2.workspace = true
tokio-tungstenite.workspace = true
tokio.workspace = true
tower.workspace = true
tracing.workspace = true
url.workspace = true
uuid.workspace = true
which.workspace = true
[lints]
workspace = true

View File

@@ -109,11 +109,14 @@ i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that requir
aws-cfn:
cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
run: ## Run the main authentik server and worker processes
$(UV) run ak allinone
run-server: ## Run the main authentik server process
$(UV) run ak server
run-watch: ## Run the authentik server and worker, with auto reloading
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs,go --no-meta --notify -- $(UV) run ak allinone
run-worker: ## Run the main authentik worker process
$(UV) run ak worker
run-worker-watch: ## Run the authentik worker, with auto reloading
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- $(UV) run ak worker
core-i18n-extract:
$(UV) run ak makemessages \

View File

@@ -1,73 +1,31 @@
"""authentik API Modelviewset tests"""
from collections.abc import Callable
from urllib.parse import urlencode
from django.test import TestCase
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
from authentik.admin.api.version_history import VersionHistoryViewSet
from authentik.api.v3.urls import router
from authentik.core.tests.utils import RequestFactory, create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.tenants.api.domains import DomainViewSet
from authentik.tenants.api.tenants import TenantViewSet
from authentik.tenants.utils import get_current_tenant
class TestModelViewSets(TestCase):
"""Test Viewset"""
def setUp(self):
self.user = create_test_admin_user()
self.factory = RequestFactory()
def viewset_tester_factory(test_viewset: type[ModelViewSet], full=True) -> dict[str, Callable]:
def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable:
"""Test Viewset"""
def test_attrs(self: TestModelViewSets) -> None:
"""Test attributes we require on all viewsets"""
self.assertIsNotNone(getattr(test_viewset, "ordering", None))
def tester(self: TestModelViewSets):
self.assertIsNotNone(getattr(test_viewset, "search_fields", None))
self.assertIsNotNone(getattr(test_viewset, "ordering", None))
filterset_class = getattr(test_viewset, "filterset_class", None)
if not filterset_class:
self.assertIsNotNone(getattr(test_viewset, "filterset_fields", None))
def test_ordering(self: TestModelViewSets) -> None:
"""Test that all ordering fields are correct"""
view = test_viewset.as_view({"get": "list"})
for ordering_field in test_viewset.ordering:
with self.subTest(ordering_field):
req = self.factory.get(
f"/?{urlencode({'ordering': ordering_field}, doseq=True)}", user=self.user
)
req.tenant = get_current_tenant()
res = view(req)
self.assertEqual(res.status_code, 200)
def test_search(self: TestModelViewSets) -> None:
"""Test that search fields are correct"""
view = test_viewset.as_view({"get": "list"})
req = self.factory.get(
f"/?{urlencode({'search': generate_id()}, doseq=True)}", user=self.user
)
req.tenant = get_current_tenant()
res = view(req)
self.assertEqual(res.status_code, 200)
cases = {
"attrs": test_attrs,
}
if full:
cases["ordering"] = test_ordering
cases["search"] = test_search
return cases
return tester
for _, viewset, _ in router.registry:
if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet):
continue
full = viewset not in [VersionHistoryViewSet, DomainViewSet, TenantViewSet]
for test, case in viewset_tester_factory(viewset, full=full).items():
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}_{test}", case)
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset))

View File

@@ -126,7 +126,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
def check_blueprint_perms(blueprint: Blueprint, user: User, explicit_action: str | None = None):
"""Check for individual permissions for each model in a blueprint"""
for entry in blueprint.iter_entries():
for entry in blueprint.entries:
full_model = entry.get_model(blueprint)
app, __, model = full_model.partition(".")
perms = [

View File

@@ -1,11 +1,8 @@
"""Test blueprints v1"""
from unittest.mock import patch
from django.test import TransactionTestCase
from authentik.blueprints.v1.importer import Importer
from authentik.enterprise.license import LicenseKey
from authentik.flows.models import Flow
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
@@ -45,45 +42,3 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
# Ensure objects do not exist
self.assertFalse(Flow.objects.filter(slug=flow_slug1))
self.assertFalse(Flow.objects.filter(slug=flow_slug2))
def test_enterprise_license_context_unlicensed(self):
"""Test enterprise license context defaults to a false boolean when unlicensed."""
license_key = LicenseKey("test", 0, "Test license", 0, 0)
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
importer = Importer.from_string("""
version: 1
entries:
- identifiers:
name: enterprise-test
slug: enterprise-test
model: authentik_flows.flow
conditions:
- !Context goauthentik.io/enterprise/licensed
attrs:
designation: stage_configuration
title: foo
""")
self.assertIs(importer.blueprint.context["goauthentik.io/enterprise/licensed"], False)
def test_enterprise_license_context_licensed(self):
"""Test enterprise license context defaults to a true boolean when licensed."""
license_key = LicenseKey("test", 253402300799, "Test license", 1000, 1000)
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
importer = Importer.from_string("""
version: 1
entries:
- identifiers:
name: enterprise-test
slug: enterprise-test
model: authentik_flows.flow
conditions:
- !Context goauthentik.io/enterprise/licensed
attrs:
designation: stage_configuration
title: foo
""")
self.assertIs(importer.blueprint.context["goauthentik.io/enterprise/licensed"], True)

View File

@@ -146,7 +146,9 @@ class Importer:
try:
from authentik.enterprise.license import LicenseKey
context["goauthentik.io/enterprise/licensed"] = LicenseKey.get_total().status().is_valid
context["goauthentik.io/enterprise/licensed"] = (
LicenseKey.get_total().status().is_valid,
)
except ModuleNotFoundError:
pass
return context

View File

@@ -64,7 +64,6 @@ class BrandSerializer(ModelSerializer):
"flow_unenrollment",
"flow_user_settings",
"flow_device_code",
"flow_lockdown",
"default_application",
"web_certificate",
"client_certificates",
@@ -118,7 +117,6 @@ class CurrentBrandSerializer(PassiveSerializer):
flow_unenrollment = CharField(source="flow_unenrollment.slug", required=False)
flow_user_settings = CharField(source="flow_user_settings.slug", required=False)
flow_device_code = CharField(source="flow_device_code.slug", required=False)
flow_lockdown = CharField(source="flow_lockdown.slug", required=False)
default_locale = CharField(read_only=True)
flags = SerializerMethodField()
@@ -156,7 +154,6 @@ class BrandViewSet(UsedByMixin, ModelViewSet):
"flow_unenrollment",
"flow_user_settings",
"flow_device_code",
"flow_lockdown",
"web_certificate",
"client_certificates",
]

View File

@@ -1,25 +0,0 @@
# Generated by Django 5.2.12 on 2026-03-14 02:58
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_brands", "0011_alter_brand_branding_default_flow_background_and_more"),
("authentik_flows", "0031_alter_flow_layout"),
]
operations = [
migrations.AddField(
model_name="brand",
name="flow_lockdown",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="brand_lockdown",
to="authentik_flows.flow",
),
),
]

View File

@@ -58,9 +58,6 @@ class Brand(SerializerModel):
flow_device_code = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code"
)
flow_lockdown = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_lockdown"
)
default_application = models.ForeignKey(
"authentik_core.Application",

View File

@@ -20,16 +20,11 @@ class TestBrands(APITestCase):
def setUp(self):
super().setUp()
self.default_flags = {}
for flag in Flag.available(visibility="public"):
self.default_flags[flag().key] = flag.get()
Brand.objects.all().delete()
@property
def default_flags(self) -> dict[str, object]:
"""Get current public flags.
Some tests define temporary Flag subclasses, so this can't be cached in setUp.
"""
return {flag().key: flag.get() for flag in Flag.available(visibility="public")}
def test_current_brand(self):
"""Test Current brand API"""
brand = create_test_brand()

View File

@@ -66,4 +66,5 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
"footer_links": tenant.footer_links,
"html_meta": {**get_http_meta()},
"version": authentik_full_version(),
"csp_nonce": request.request_id,
}

View File

@@ -30,8 +30,6 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
SAML_STATUS_SUCCESS = "urn:oasis:names:tc:SAML:2.0:status:Success"
DEFAULT_ISSUER = "authentik"
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2

View File

@@ -47,8 +47,7 @@ class ApplicationEntitlementViewSet(UsedByMixin, ModelViewSet):
search_fields = [
"pbm_uuid",
"name",
"app__name",
"app__slug",
"app",
"attributes",
]
filterset_fields = [

View File

@@ -4,7 +4,7 @@ from collections.abc import Iterator
from copy import copy
from django.core.cache import cache
from django.db.models import Case, QuerySet
from django.db.models import Case, Q, QuerySet
from django.db.models.expressions import When
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
@@ -120,7 +120,6 @@ class ApplicationSerializer(ModelSerializer):
"meta_publisher",
"policy_engine_mode",
"group",
"meta_hide",
]
extra_kwargs = {
"backchannel_providers": {"required": False},
@@ -284,12 +283,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
) == "true"
queryset = self._filter_queryset_for_list(self.get_queryset())
queryset = queryset.exclude(meta_hide=True)
if only_with_launch_url:
# Pre-filter at DB level to skip expensive per-app policy evaluation
# for apps that can never appear in the launcher (no meta_launch_url
# and no provider, so no possible launch URL).
queryset = queryset.exclude(meta_launch_url="", provider__isnull=True)
# for apps that can never appear in the launcher:
# - No meta_launch_url AND no provider: no possible launch URL
# - meta_launch_url="blank://blank": documented convention to hide from launcher
queryset = queryset.exclude(
Q(meta_launch_url="", provider__isnull=True) | Q(meta_launch_url="blank://blank")
)
paginator: Pagination = self.paginator
paginated_apps = paginator.paginate_queryset(queryset, request)

View File

@@ -32,19 +32,19 @@ from authentik.rbac.decorators import permission_required
class UserAgentDeviceDict(TypedDict):
"""User agent device"""
brand: str | None = None
brand: str
family: str
model: str | None = None
model: str
class UserAgentOSDict(TypedDict):
"""User agent os"""
family: str
major: str | None = None
minor: str | None = None
patch: str | None = None
patch_minor: str | None = None
major: str
minor: str
patch: str
patch_minor: str
class UserAgentBrowserDict(TypedDict):

View File

@@ -19,7 +19,7 @@ from rest_framework.authentication import SessionAuthentication
from rest_framework.decorators import action
from rest_framework.fields import CharField, IntegerField, SerializerMethodField
from rest_framework.permissions import IsAuthenticated
from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ValidationError
@@ -37,77 +37,6 @@ from authentik.endpoints.connectors.agent.auth import AgentAuth
from authentik.rbac.api.roles import RoleSerializer
from authentik.rbac.decorators import permission_required
class BulkManyRelatedField(ManyRelatedField):
"""ManyRelatedField that validates all PKs in a single query instead of one per PK."""
def to_internal_value(self, data):
if isinstance(data, str) or not hasattr(data, "__iter__"):
self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail("empty")
child = self.child_relation
pk_field = child.pk_field
# Coerce PKs through pk_field if defined
pk_map = {}
for item in data:
if isinstance(item, bool):
self.fail("incorrect_type", data_type=type(item).__name__)
pk = pk_field.to_internal_value(item) if pk_field else item
pk_map[pk] = item # map coerced PK -> original value for error reporting
queryset = child.get_queryset()
# Use count to validate all PKs exist in a single query
found_count = queryset.filter(pk__in=pk_map.keys()).count()
if found_count < len(pk_map):
# Some PKs not found — fall back to per-PK checks for error reporting.
# This only runs when there's an actual validation error (rare path).
for pk, original in pk_map.items():
if not queryset.filter(pk=pk).exists():
child.fail("does_not_exist", pk_value=original)
# Return raw PKs — Django's M2M set() accepts both objects and PKs,
# using get_prep_value() for type coercion. This avoids loading all
# objects into memory and avoids triggering post_init signals.
return list(pk_map.keys())
def to_representation(self, iterable):
# For non-prefetched querysets, get PKs directly without loading model instances.
# When prefetched, _result_cache is a list (possibly empty); when not, it's None.
if hasattr(iterable, "values_list") and getattr(iterable, "_result_cache", None) is None:
return list(iterable.values_list("pk", flat=True))
return super().to_representation(iterable)
class BulkPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""PrimaryKeyRelatedField that uses bulk validation when many=True."""
@classmethod
def many_init(cls, *args, **kwargs):
allow_empty = kwargs.pop("allow_empty", None)
max_length = kwargs.pop("max_length", None)
min_length = kwargs.pop("min_length", None)
child_relation = cls(*args, **kwargs)
list_kwargs = {
"child_relation": child_relation,
}
if allow_empty is not None:
list_kwargs["allow_empty"] = allow_empty
if max_length is not None:
list_kwargs["max_length"] = max_length
if min_length is not None:
list_kwargs["min_length"] = min_length
list_kwargs.update(
{
key: value
for key, value in kwargs.items()
if key in ("required", "default", "source")
}
)
return BulkManyRelatedField(**list_kwargs)
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
"pk",
"username",
@@ -150,7 +79,6 @@ class GroupSerializer(ModelSerializer):
"""Group Serializer"""
attributes = JSONDictField(required=False)
users = BulkPrimaryKeyRelatedField(queryset=User.objects.all(), many=True, default=list)
parents = PrimaryKeyRelatedField(queryset=Group.objects.all(), many=True, required=False)
parents_obj = SerializerMethodField(allow_null=True)
children_obj = SerializerMethodField(allow_null=True)
@@ -265,6 +193,9 @@ class GroupSerializer(ModelSerializer):
"children_obj",
]
extra_kwargs = {
"users": {
"default": list,
},
"children": {
"required": False,
"default": list,
@@ -294,7 +225,6 @@ class GroupFilter(FilterSet):
members_by_pk = ModelMultipleChoiceFilter(
field_name="users",
queryset=User.objects.all(),
distinct=False,
)
def filter_attributes(self, queryset, name, value):
@@ -346,8 +276,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
]
def get_queryset(self):
# Always prefetch parents and children since their PKs are always serialized
base_qs = Group.objects.all().prefetch_related("roles", "parents", "children")
base_qs = Group.objects.all().prefetch_related("roles")
if self.serializer_class(context={"request": self.request})._should_include_users:
# Only fetch fields needed by PartialUserSerializer to reduce DB load and instantiation
@@ -358,9 +287,16 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
)
)
# When include_users=false, skip users prefetch entirely.
# BulkManyRelatedField.to_representation will use values_list to get PKs
# directly without loading User instances into memory.
else:
base_qs = base_qs.prefetch_related(
Prefetch("users", queryset=User.objects.all().only("id"))
)
if self.serializer_class(context={"request": self.request})._should_include_children:
base_qs = base_qs.prefetch_related("children")
if self.serializer_class(context={"request": self.request})._should_include_parents:
base_qs = base_qs.prefetch_related("parents")
return base_qs

View File

@@ -6,7 +6,6 @@ from typing import Any
from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import AnonymousUser, Permission
from django.db.models import Exists, OuterRef, Prefetch, Q
from django.db.transaction import atomic
from django.db.utils import IntegrityError
from django.urls import reverse_lazy
@@ -14,7 +13,6 @@ from django.utils.http import urlencode
from django.utils.text import slugify
from django.utils.timezone import now
from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy
from django_filters.filters import (
BooleanFilter,
CharFilter,
@@ -107,10 +105,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger()
INVALID_PASSWORD_HASH_MESSAGE = gettext_lazy(
"Invalid password hash format. Must be a valid Django password hash."
)
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
@@ -137,7 +131,7 @@ class PartialGroupSerializer(ModelSerializer):
class UserSerializer(ModelSerializer):
"""User Serializer"""
is_superuser = SerializerMethodField()
is_superuser = BooleanField(read_only=True)
avatar = SerializerMethodField()
attributes = JSONDictField(required=False)
groups = PrimaryKeyRelatedField(
@@ -174,14 +168,6 @@ class UserSerializer(ModelSerializer):
return True
return str(request.query_params.get("include_roles", "true")).lower() == "true"
@extend_schema_field(BooleanField)
def get_is_superuser(self, instance: User) -> bool:
"""Use annotation if available to avoid N+1 query"""
ann = getattr(instance, "_annotated_is_superuser", None)
if ann is not None:
return ann
return instance.is_superuser
@extend_schema_field(PartialGroupSerializer(many=True))
def get_groups_obj(self, instance: User) -> list[PartialGroupSerializer] | None:
if not self._should_include_groups:
@@ -195,79 +181,47 @@ class UserSerializer(ModelSerializer):
return RoleSerializer(instance.roles, many=True).data
def __init__(self, *args, **kwargs):
"""Setting password and permissions directly is allowed only in blueprints."""
super().__init__(*args, **kwargs)
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["password"] = CharField(required=False, allow_null=True)
self.fields["password_hash"] = CharField(required=False, allow_null=True)
self.fields["permissions"] = ListField(
required=False,
child=ChoiceField(choices=get_permission_choices()),
)
def create(self, validated_data: dict) -> User:
"""Create a user, with blueprint-only password and permission writes."""
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
if is_blueprint:
password = validated_data.pop("password", None)
password_hash = validated_data.pop("password_hash", None)
permissions = validated_data.pop("permissions", [])
self._validate_password_inputs(password, password_hash)
"""If this serializer is used in the blueprint context, we allow for
directly setting a password. However should be done via the `set_password`
method instead of directly setting it like rest_framework."""
password = validated_data.pop("password", None)
perms_qs = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
instance: User = super().create(validated_data)
if is_blueprint:
self._set_password(instance, password, password_hash)
perms_qs = Permission.objects.filter(
codename__in=[permission.split(".")[1] for permission in permissions]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
instance.assign_perms_to_managed_role(perms_list)
self._ensure_password_not_empty(instance)
self._set_password(instance, password)
instance.assign_perms_to_managed_role(perms_list)
return instance
def update(self, instance: User, validated_data: dict) -> User:
"""Update a user, with blueprint-only password and permission writes."""
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
if is_blueprint:
password = validated_data.pop("password", None)
password_hash = validated_data.pop("password_hash", None)
permissions = validated_data.pop("permissions", [])
self._validate_password_inputs(password, password_hash)
"""Same as `create` above, set the password directly if we're in a blueprint
context"""
password = validated_data.pop("password", None)
perms_qs = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
instance = super().update(instance, validated_data)
if is_blueprint:
self._set_password(instance, password, password_hash)
perms_qs = Permission.objects.filter(
codename__in=[permission.split(".")[1] for permission in permissions]
).values_list("content_type__app_label", "codename")
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
instance.assign_perms_to_managed_role(perms_list)
self._ensure_password_not_empty(instance)
self._set_password(instance, password)
instance.assign_perms_to_managed_role(perms_list)
return instance
def _validate_password_inputs(self, password: str | None, password_hash: str | None):
"""Validate mutually-exclusive password inputs before any model mutation."""
if password is not None and password_hash is not None:
raise ValidationError(_("Cannot set both password and password_hash. Use only one."))
if password_hash is None:
return
try:
User.validate_password_hash(password_hash)
except ValueError as exc:
LOGGER.warning("Failed to identify password hash format", exc_info=exc)
raise ValidationError(INVALID_PASSWORD_HASH_MESSAGE) from exc
def _set_password(self, instance: User, password: str | None, password_hash: str | None = None):
"""Set password from plain text or hash."""
if password_hash is not None:
instance.set_password_from_hash(password_hash)
instance.save()
elif password:
def _set_password(self, instance: User, password: str | None):
"""Set password of user if we're in a blueprint context, and if it's an empty
string then use an unusable password"""
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
instance.set_password(password)
instance.save()
def _ensure_password_not_empty(self, instance: User):
"""Store an explicit unusable password instead of an empty password field."""
if len(instance.password) == 0:
instance.set_unusable_password()
instance.save()
@@ -436,12 +390,6 @@ class UserPasswordSetSerializer(PassiveSerializer):
password = CharField(required=True)
class UserPasswordHashSetSerializer(PassiveSerializer):
"""Payload to set a users' password hash directly"""
password = CharField(required=True)
class UserServiceAccountSerializer(PassiveSerializer):
"""Payload to create a service account"""
@@ -563,9 +511,6 @@ class UsersFilter(FilterSet):
class UserViewSet(
ConditionalInheritance(
"authentik.enterprise.stages.account_lockdown.api.UserAccountLockdownMixin"
),
ConditionalInheritance("authentik.enterprise.reports.api.reports.ExportMixin"),
UsedByMixin,
ModelViewSet,
@@ -596,30 +541,10 @@ class UserViewSet(
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()
# Always prefetch groups since group PKs are always serialized.
# Use full prefetch when include_groups=true (for groups_obj), ID-only otherwise.
if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("groups")
else:
base_qs = base_qs.prefetch_related(
Prefetch("groups", queryset=Group.objects.all().only("group_uuid"))
)
if self.serializer_class(context={"request": self.request})._should_include_roles:
base_qs = base_qs.prefetch_related("roles")
else:
base_qs = base_qs.prefetch_related(
Prefetch("roles", queryset=Role.objects.all().only("uuid"))
)
# Annotate is_superuser to avoid N+1 query per user
base_qs = base_qs.annotate(
_annotated_is_superuser=Exists(
Group.objects.filter(
is_superuser=True,
).filter(
Q(users=OuterRef("pk")) | Q(descendant_nodes__descendant__users=OuterRef("pk"))
)
)
)
return base_qs
@extend_schema(
@@ -788,11 +713,6 @@ class UserViewSet(
self.request.session.modified = True
return Response(serializer.initial_data)
def _update_session_hash_after_password_change(self, request: Request, user: User):
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
LOGGER.debug("Updating session hash after password change")
update_session_auth_hash(self.request, user)
@permission_required("authentik_core.reset_user_password")
@extend_schema(
request=UserPasswordSetSerializer,
@@ -816,45 +736,9 @@ class UserViewSet(
except (ValidationError, IntegrityError) as exc:
LOGGER.debug("Failed to set password", exc=exc)
return Response(status=400)
self._update_session_hash_after_password_change(request, user)
return Response(status=204)
@permission_required("authentik_core.reset_user_password")
@extend_schema(
request=UserPasswordHashSetSerializer,
responses={
204: OpenApiResponse(description="Successfully changed password"),
400: OpenApiResponse(description="Bad request"),
},
)
@action(
detail=True,
methods=["POST"],
permission_classes=[IsAuthenticated],
)
@validate(UserPasswordHashSetSerializer)
def set_password_hash(
self, request: Request, pk: int, body: UserPasswordHashSetSerializer
) -> Response:
"""Set a user's password from a pre-hashed Django password value.
Submit the Django password hash in the shared ``password`` request field.
This updates authentik's local password verifier only. It does not attempt
to propagate the password change to LDAP or Kerberos because no raw password
is available from the request payload.
"""
user: User = self.get_object()
try:
user.set_password_from_hash(body.validated_data["password"], request=request)
user.save()
except ValueError as exc:
LOGGER.debug("Failed to set password hash", exc=exc)
return Response(data={"password": [INVALID_PASSWORD_HASH_MESSAGE]}, status=400)
except (ValidationError, IntegrityError) as exc:
LOGGER.debug("Failed to set password hash", exc=exc)
return Response(status=400)
self._update_session_hash_after_password_change(request, user)
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
LOGGER.debug("Updating session hash after password change")
update_session_auth_hash(self.request, user)
return Response(status=204)
@permission_required("authentik_core.reset_user_password")

View File

@@ -1,28 +0,0 @@
"""Hash password using Django's password hashers"""
from django.contrib.auth.hashers import make_password
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):
"""Hash a password using Django's password hashers"""
help = "Hash a password for use with AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"
def add_arguments(self, parser):
parser.add_argument(
"password",
type=str,
help="Password to hash",
)
def handle(self, *args, **options):
password = options["password"]
if not password:
raise CommandError("Password cannot be empty")
try:
hashed = make_password(password)
self.stdout.write(hashed)
except ValueError as exc:
raise CommandError(f"Error hashing password: {exc}") from exc

View File

@@ -1,33 +0,0 @@
# Generated by Django 5.2.12 on 2026-04-09 18:04
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_blank_launch_url(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Application = apps.get_model("authentik_core", "Application")
Application.objects.using(db_alias).filter(meta_launch_url="blank://blank").update(
meta_hide=True, meta_launch_url=""
)
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0058_setup"),
]
operations = [
migrations.AddField(
model_name="application",
name="meta_hide",
field=models.BooleanField(
default=False,
help_text="Hide this application from the user's My applications page.",
),
),
migrations.RunPython(migrate_blank_launch_url, migrations.RunPython.noop),
]

View File

@@ -10,7 +10,7 @@ from uuid import uuid4
import pgtrigger
from deepmerge import always_merger
from django.contrib.auth.hashers import check_password, identify_hasher
from django.contrib.auth.hashers import check_password
from django.contrib.auth.models import AbstractUser, Permission
from django.contrib.auth.models import UserManager as DjangoUserManager
from django.contrib.sessions.base_session import AbstractBaseSession
@@ -560,33 +560,6 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
self.password_change_date = now()
return super().set_password(raw_password)
@staticmethod
def validate_password_hash(password_hash: str):
"""Validate that the value is a recognized Django password hash."""
identify_hasher(password_hash) # Raises ValueError if invalid
def set_password_from_hash(self, password_hash: str, signal=True, sender=None, request=None):
"""Set password directly from a pre-hashed value.
Unlike set_password(), this does not hash the input again. The provided value
must already be a valid Django password hash, and it is stored directly on the
user after validation.
Because no raw password is available, downstream password sync integrations
such as LDAP and Kerberos cannot be updated from this code path.
Raises ValueError if the hash format is not recognized.
"""
self.validate_password_hash(password_hash)
if self.pk and signal:
from authentik.core.signals import password_hash_changed
if not sender:
sender = self
password_hash_changed.send(sender=sender, user=self, request=request)
self.password = password_hash
self.password_change_date = now()
def check_password(self, raw_password: str) -> bool:
"""
Return a boolean of whether the raw_password was correct. Handles
@@ -762,9 +735,6 @@ class Application(SerializerModel, PolicyBindingModel):
meta_icon = FileField(default="", blank=True)
meta_description = models.TextField(default="", blank=True)
meta_publisher = models.TextField(default="", blank=True)
meta_hide = models.BooleanField(
default=False, help_text=_("Hide this application from the user's My applications page.")
)
objects = ApplicationQuerySet.as_manager()

View File

@@ -16,11 +16,7 @@ LOGGER = get_logger()
@receiver(post_startup)
def post_startup_setup_bootstrap(sender, **_):
if (
not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD")
and not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH")
and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN")
):
if not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD") and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN"):
return
LOGGER.info("Configuring authentik through bootstrap environment variables")
content = BlueprintInstance(path=BOOTSTRAP_BLUEPRINT).retrieve()

View File

@@ -1,5 +1,6 @@
"""authentik core signals"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_in
from django.core.cache import cache
@@ -23,8 +24,6 @@ from authentik.root.ws.consumer import build_device_group
# Arguments: user: User, password: str
password_changed = Signal()
# Arguments: user: User, request: HttpRequest | None
password_hash_changed = Signal()
# Arguments: credentials: dict[str, any], request: HttpRequest,
# stage: Stage, context: dict[str, any]
login_failed = Signal()
@@ -58,7 +57,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
layer = get_channel_layer()
device_cookie = request.COOKIES.get("authentik_device")
if device_cookie:
layer.group_send_blocking(
async_to_sync(layer.group_send)(
build_device_group(device_cookie),
{"type": "event.session.authenticated"},
)

View File

@@ -1,7 +1,7 @@
{% load i18n %}
{% get_current_language as LANGUAGE_CODE %}
<script data-id="authentik-config">
<script data-id="authentik-config" nonce="{{ csp_nonce }}">
"use strict";
window.authentik = {

View File

@@ -14,6 +14,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
{# Darkreader breaks the site regardless of theme as its not compatible with webcomponents, and we default to a dark theme based on preferred colour-scheme #}
<meta name="darkreader-lock">
<script nonce="{{ csp_nonce }}">window.litNonce = "{{ csp_nonce }}";</script>
<title>{% block title %}{% trans title|default:brand.branding_title %}{% endblock %}</title>
<link rel="icon" href="{{ brand.branding_favicon_url }}">
<link rel="shortcut icon" href="{{ brand.branding_favicon_url }}">
@@ -27,7 +28,7 @@
{% include "base/theme.html" %}
<style data-id="brand-css">{{ brand_css }}</style>
<style data-id="brand-css" nonce="{{ csp_nonce }}">{{ brand_css }}</style>
<script src="{% versioned_script 'dist/poly-%v.js' %}" type="module"></script>
{% block head %}
{% endblock %}

View File

@@ -9,7 +9,7 @@
<meta name="color-scheme" content="light" />
<meta name="theme-color" content="#ffffff">
{% else %}
<script data-id="theme-script">
<script data-id="theme-script" nonce="{{ csp_nonce }}">
"use strict";
(function () {
@@ -22,7 +22,7 @@
if (!(["auto", "light", "dark"].includes(locallyStoredTheme))) {
locallyStoredTheme = null;
}
const initialThemeChoice =
new URLSearchParams(window.location.search).get("theme") || locallyStoredTheme;

View File

@@ -10,9 +10,9 @@
{% endblock %}
{% block head %}
<style data-id="static-styles">
<style data-id="static-styles" nonce="{{ csp_nonce }}">
:root {
--ak-global--background-image: url("{{ request.brand.branding_default_flow_background_url|iriencode|safe }}");
--ak-global--background-image: url("{{ request.brand.branding_default_flow_background_url }}");
}
</style>

View File

@@ -129,7 +129,6 @@ class TestApplicationsAPI(APITestCase):
"meta_icon_url": None,
"meta_icon_themed_urls": None,
"meta_description": "",
"meta_hide": False,
"meta_publisher": "",
"policy_engine_mode": "any",
},
@@ -188,14 +187,12 @@ class TestApplicationsAPI(APITestCase):
"meta_icon_url": None,
"meta_icon_themed_urls": None,
"meta_description": "",
"meta_hide": False,
"meta_publisher": "",
"policy_engine_mode": "any",
},
{
"launch_url": None,
"meta_description": "",
"meta_hide": False,
"meta_icon": "",
"meta_icon_url": None,
"meta_icon_themed_urls": None,

View File

@@ -1,28 +0,0 @@
"""Tests for hash_password management command."""
from io import StringIO
from django.contrib.auth.hashers import check_password
from django.core.management import call_command
from django.core.management.base import CommandError
from django.test import TestCase
class TestHashPasswordCommand(TestCase):
"""Test hash_password management command."""
def test_hash_password(self):
"""Test hashing a password."""
out = StringIO()
call_command("hash_password", "test123", stdout=out)
hashed = out.getvalue().strip()
self.assertTrue(hashed.startswith("pbkdf2_sha256$"))
self.assertTrue(check_password("test123", hashed))
def test_hash_password_empty_fails(self):
"""Test that empty password raises error."""
with self.assertRaises(CommandError) as ctx:
call_command("hash_password", "")
self.assertIn("Password cannot be empty", str(ctx.exception))

View File

@@ -1,7 +1,6 @@
from http import HTTPStatus
from os import environ
from django.contrib.auth.hashers import make_password
from django.urls import reverse
from authentik.blueprints.tests import apply_blueprint
@@ -17,7 +16,6 @@ from authentik.tenants.flags import patch_flag
class TestSetup(FlowTestCase):
def tearDown(self):
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD", None)
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH", None)
environ.pop("AUTHENTIK_BOOTSTRAP_TOKEN", None)
@patch_flag(Setup, True)
@@ -156,19 +154,3 @@ class TestSetup(FlowTestCase):
token = Token.objects.filter(identifier="authentik-bootstrap-token").first()
self.assertEqual(token.intent, TokenIntents.INTENT_API)
self.assertEqual(token.key, environ["AUTHENTIK_BOOTSTRAP_TOKEN"])
def test_setup_bootstrap_env_password_hash(self):
"""Test setup with password hash env var"""
User.objects.filter(username="akadmin").delete()
Setup.set(False)
password = generate_id()
password_hash = make_password(password)
environ["AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"] = password_hash
pre_startup.send(sender=self)
post_startup.send(sender=self)
self.assertTrue(Setup.get())
user = User.objects.get(username="akadmin")
self.assertEqual(user.password, password_hash)
self.assertTrue(user.check_password(password))

View File

@@ -1,15 +1,8 @@
"""user tests"""
from unittest.mock import patch
from django.contrib.auth.hashers import make_password
from django.test.testcases import TestCase
from rest_framework.exceptions import ValidationError
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.users import UserSerializer
from authentik.core.models import User
from authentik.core.signals import password_changed, password_hash_changed
from authentik.events.models import Event
from authentik.lib.generators import generate_id
@@ -40,99 +33,3 @@ class TestUsers(TestCase):
self.assertEqual(Event.objects.count(), 1)
user.ak_groups.all()
self.assertEqual(Event.objects.count(), 1)
def test_set_password_from_hash_signal_skips_source_sync_receivers(self):
"""Test hash password updates do not expose a raw password to sync receivers."""
user = User.objects.create(
username=generate_id(),
attributes={"distinguishedName": "cn=test,ou=users,dc=example,dc=com"},
)
password_changed_captured = []
password_hash_changed_captured = []
dispatch_uid = generate_id()
hash_dispatch_uid = generate_id()
def password_changed_receiver(sender, **kwargs):
password_changed_captured.append(kwargs)
def password_hash_changed_receiver(sender, **kwargs):
password_hash_changed_captured.append(kwargs)
password_changed.connect(password_changed_receiver, dispatch_uid=dispatch_uid)
password_hash_changed.connect(
password_hash_changed_receiver, dispatch_uid=hash_dispatch_uid
)
try:
with (
patch(
"authentik.sources.ldap.signals.LDAPSource.objects.filter"
) as ldap_sources_filter,
patch(
"authentik.sources.kerberos.signals."
"UserKerberosSourceConnection.objects.select_related"
) as kerberos_connections_select,
):
user.set_password_from_hash(make_password("new-password")) # nosec
user.save()
finally:
password_changed.disconnect(dispatch_uid=dispatch_uid)
password_hash_changed.disconnect(dispatch_uid=hash_dispatch_uid)
self.assertEqual(password_changed_captured, [])
self.assertEqual(len(password_hash_changed_captured), 1)
ldap_sources_filter.assert_not_called()
kerberos_connections_select.assert_not_called()
class TestUserSerializerPasswordHash(TestCase):
"""Test UserSerializer password_hash support in blueprint context."""
def test_password_hash_sets_password_directly(self):
"""Test a valid password hash is stored without re-hashing."""
password = "test-password-123" # nosec
password_hash = make_password(password)
serializer = UserSerializer(
data={
"username": generate_id(),
"name": "Test User",
"password_hash": password_hash,
},
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
)
self.assertTrue(serializer.is_valid(), serializer.errors)
user = serializer.save()
self.assertEqual(user.password, password_hash)
self.assertTrue(user.check_password(password))
self.assertIsNotNone(user.password_change_date)
def test_password_hash_rejects_invalid_format(self):
"""Test invalid password hash values are rejected."""
serializer = UserSerializer(
data={
"username": generate_id(),
"name": "Test User",
"password_hash": "not-a-valid-hash",
},
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
)
self.assertTrue(serializer.is_valid(), serializer.errors)
with self.assertRaises(ValidationError) as ctx:
serializer.save()
self.assertIn("Invalid password hash format", str(ctx.exception))
def test_password_hash_ignored_outside_blueprint_context(self):
"""Test password_hash is not accepted by the regular serializer."""
serializer = UserSerializer(
data={
"username": generate_id(),
"name": "Test User",
"password_hash": make_password("test"), # nosec
}
)
self.assertTrue(serializer.is_valid(), serializer.errors)
self.assertNotIn("password_hash", serializer.validated_data)

View File

@@ -3,7 +3,6 @@
from datetime import datetime, timedelta
from json import loads
from django.contrib.auth.hashers import make_password
from django.urls.base import reverse
from django.utils.timezone import now
from rest_framework.test import APITestCase
@@ -27,9 +26,6 @@ from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignatio
from authentik.lib.generators import generate_id, generate_key
from authentik.stages.email.models import EmailStage
INVALID_PASSWORD_HASH = "not-a-valid-hash"
INVALID_PASSWORD_HASH_ERROR = "Invalid password hash format. Must be a valid Django password hash."
class TestUsersAPI(APITestCase):
"""Test Users API"""
@@ -38,20 +34,6 @@ class TestUsersAPI(APITestCase):
self.admin = create_test_admin_user()
self.user = create_test_user()
def _set_password_hash(self, user: User, password_hash: str, client=None):
return (client or self.client).post(
reverse("authentik_api:user-set-password-hash", kwargs={"pk": user.pk}),
data={"password": password_hash},
)
def _assert_password_hash_set(
self, user: User, password: str, password_hash: str, response
) -> None:
self.assertEqual(response.status_code, 204, response.data)
user.refresh_from_db()
self.assertEqual(user.password, password_hash)
self.assertTrue(user.check_password(password))
def test_filter_type(self):
"""Test API filtering by type"""
self.client.force_login(self.admin)
@@ -131,26 +113,6 @@ class TestUsersAPI(APITestCase):
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(response.content, {"password": ["This field may not be blank."]})
def test_set_password_hash(self):
"""Test setting a user's password from a hash."""
self.client.force_login(self.admin)
password = generate_key()
password_hash = make_password(password)
response = self._set_password_hash(self.user, password_hash)
self._assert_password_hash_set(self.user, password, password_hash, response)
def test_set_password_hash_invalid(self):
"""Test invalid password hashes are rejected."""
self.client.force_login(self.admin)
response = self._set_password_hash(self.user, INVALID_PASSWORD_HASH)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"password": [INVALID_PASSWORD_HASH_ERROR]},
)
def test_recovery(self):
"""Test user recovery link"""
flow = create_test_flow(
@@ -299,29 +261,6 @@ class TestUsersAPI(APITestCase):
self.assertTrue(token_filter.exists())
self.assertTrue(token_filter.first().expiring)
def test_service_account_set_password_hash(self):
"""Service account password hash can be set through the API."""
self.client.force_login(self.admin)
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"create_group": False,
},
)
self.assertEqual(response.status_code, 200, response.data)
body = loads(response.content)
user = User.objects.get(pk=body["user_pk"])
self.assertEqual(user.type, UserTypes.SERVICE_ACCOUNT)
self.assertFalse(user.has_usable_password())
password = generate_key()
password_hash = make_password(password)
response = self._set_password_hash(user, password_hash)
self._assert_password_hash_set(user, password, password_hash, response)
def test_service_account_no_expire(self):
"""Service account creation without token expiration"""
self.client.force_login(self.admin)

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from django.db.models import Exists, OuterRef, Q, Subquery
from django.db.models import BooleanField as ModelBooleanField
from django.db.models import Case, Q, Value, When
from django_filters.rest_framework import BooleanFilter, FilterSet
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
@@ -13,7 +14,7 @@ from rest_framework.viewsets import GenericViewSet
from authentik.core.api.utils import ModelSerializer
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.lifecycle.api.reviews import ReviewSerializer
from authentik.enterprise.lifecycle.models import LifecycleIteration, LifecycleRule, ReviewState
from authentik.enterprise.lifecycle.models import LifecycleIteration, ReviewState
from authentik.enterprise.lifecycle.utils import (
ContentTypeField,
ReviewerGroupSerializer,
@@ -25,25 +26,20 @@ from authentik.enterprise.lifecycle.utils import (
from authentik.lib.utils.time import timedelta_from_string
class RelatedRuleSerializer(EnterpriseRequiredMixin, ModelSerializer):
reviewer_groups = ReviewerGroupSerializer(many=True, read_only=True)
min_reviewers = IntegerField(read_only=True)
reviewers = ReviewerUserSerializer(many=True, read_only=True)
class Meta:
model = LifecycleRule
fields = ["id", "name", "reviewer_groups", "min_reviewers", "reviewers"]
class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
content_type = ContentTypeField()
object_verbose = SerializerMethodField()
rule = RelatedRuleSerializer(read_only=True)
object_admin_url = SerializerMethodField(read_only=True)
grace_period_end = SerializerMethodField(read_only=True)
reviews = ReviewSerializer(many=True, read_only=True, source="review_set.all")
user_can_review = SerializerMethodField(read_only=True)
reviewer_groups = ReviewerGroupSerializer(
many=True, read_only=True, source="rule.reviewer_groups"
)
min_reviewers = IntegerField(read_only=True, source="rule.min_reviewers")
reviewers = ReviewerUserSerializer(many=True, read_only=True, source="rule.reviewers")
next_review_date = SerializerMethodField(read_only=True)
class Meta:
@@ -59,8 +55,10 @@ class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
"grace_period_end",
"next_review_date",
"reviews",
"rule",
"user_can_review",
"reviewer_groups",
"min_reviewers",
"reviewers",
]
read_only_fields = fields
@@ -90,55 +88,43 @@ class IterationViewSet(EnterpriseRequiredMixin, CreateModelMixin, GenericViewSet
queryset = LifecycleIteration.objects.all()
serializer_class = LifecycleIterationSerializer
ordering = ["-opened_on"]
ordering_fields = [
"state",
"content_type__model",
"rule__name",
"opened_on",
"grace_period_end",
]
ordering_fields = ["state", "content_type__model", "opened_on", "grace_period_end"]
filterset_class = LifecycleIterationFilterSet
def get_queryset(self):
user = self.request.user
return self.queryset.annotate(
user_is_reviewer=Exists(
LifecycleRule.objects.filter(
pk=OuterRef("rule_id"),
).filter(
Q(reviewers=user) | Q(reviewer_groups__in=user.groups.all().with_ancestors())
)
user_is_reviewer=Case(
When(
Q(rule__reviewers=user)
| Q(rule__reviewer_groups__in=user.groups.all().with_ancestors()),
then=Value(True),
),
default=Value(False),
output_field=ModelBooleanField(),
)
)
).distinct()
@extend_schema(
operation_id="lifecycle_iterations_list_latest",
responses={200: LifecycleIterationSerializer(many=True)},
)
@action(
detail=False,
pagination_class=None,
methods=["get"],
url_path=r"latest/(?P<content_type>[^/]+)/(?P<object_id>[^/]+)",
)
def latest_iterations(self, request: Request, content_type: str, object_id: str) -> Response:
def latest_iteration(self, request: Request, content_type: str, object_id: str) -> Response:
ct = parse_content_type(content_type)
latest_ids_subquery = (
LifecycleIteration.objects.filter(
rule=OuterRef("rule"),
content_type__app_label=ct["app_label"],
content_type__model=ct["model"],
object_id=object_id,
try:
obj = (
self.get_queryset()
.filter(
content_type__app_label=ct["app_label"],
content_type__model=ct["model"],
object_id=object_id,
)
.latest("opened_on")
)
.order_by("-opened_on")
.values("id")[:1]
)
latest_per_rule = LifecycleIteration.objects.filter(
content_type__app_label=ct["app_label"],
content_type__model=ct["model"],
object_id=object_id,
).filter(id=Subquery(latest_ids_subquery))
serializer = self.get_serializer(latest_per_rule, many=True)
except LifecycleIteration.DoesNotExist:
return Response(status=404)
serializer = self.get_serializer(obj)
return Response(serializer.data)
@extend_schema(

View File

@@ -84,6 +84,23 @@ class LifecycleRuleSerializer(EnterpriseRequiredMixin, ModelSerializer):
raise ValidationError(
{"grace_period": _("Grace period must be shorter than the interval.")}
)
if "content_type" in attrs or "object_id" in attrs:
content_type = attrs.get("content_type", getattr(self.instance, "content_type", None))
object_id = attrs.get("object_id", getattr(self.instance, "object_id", None))
if content_type is not None and object_id is None:
existing = LifecycleRule.objects.filter(
content_type=content_type, object_id__isnull=True
)
if self.instance:
existing = existing.exclude(pk=self.instance.pk)
if existing.exists():
raise ValidationError(
{
"content_type": _(
"Only one type-wide rule for each object type is allowed."
)
}
)
return attrs

View File

@@ -1,21 +0,0 @@
# Generated by Django 5.2.11 on 2026-03-05 11:27
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_lifecycle", "0002_alter_lifecycleiteration_opened_on"),
]
operations = [
migrations.RemoveConstraint(
model_name="lifecyclerule",
name="uniq_lifecycle_rule_ct_null_object",
),
migrations.AlterUniqueTogether(
name="lifecyclerule",
unique_together=set(),
),
]

View File

@@ -56,6 +56,14 @@ class LifecycleRule(SerializerModel):
class Meta:
indexes = [models.Index(fields=["content_type"])]
unique_together = [["content_type", "object_id"]]
constraints = [
models.UniqueConstraint(
fields=["content_type"],
condition=Q(object_id__isnull=True),
name="uniq_lifecycle_rule_ct_null_object",
)
]
@property
def serializer(self) -> type[BaseSerializer]:
@@ -74,6 +82,12 @@ class LifecycleRule(SerializerModel):
qs = self.content_type.get_all_objects_for_this_type()
if self.object_id:
qs = qs.filter(pk=self.object_id)
else:
qs = qs.exclude(
pk__in=LifecycleRule.objects.filter(
content_type=self.content_type, object_id__isnull=False
).values_list(Cast("object_id", output_field=self._get_pk_field()), flat=True)
)
return qs
def _get_stale_iterations(self) -> QuerySet[LifecycleIteration]:
@@ -93,7 +107,8 @@ class LifecycleRule(SerializerModel):
def _get_newly_due_objects(self) -> QuerySet:
recent_iteration_ids = LifecycleIteration.objects.filter(
rule=self,
content_type=self.content_type,
object_id__isnull=False,
opened_on__gte=start_of_day(
timezone.now() + timedelta(days=1) - timedelta_from_string(self.interval)
),
@@ -199,15 +214,9 @@ class LifecycleIteration(SerializerModel, ManagedModel):
}
def initialize(self):
if (self.content_type.app_label, self.content_type.model) == ("authentik_core", "group"):
object_label = self.object.name
elif (self.content_type.app_label, self.content_type.model) == ("authentik_rbac", "role"):
object_label = self.object.name
else:
object_label = str(self.object)
event = Event.new(
EventAction.REVIEW_INITIATED,
message=_(f"Access review is due for {self.content_type.name.lower()} {object_label}"),
message=_(f"Access review is due for {self.content_type.name} {str(self.object)}"),
**self._get_event_args(),
)
event.save()

View File

@@ -3,7 +3,6 @@ from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from authentik.enterprise.lifecycle.models import LifecycleRule, ReviewState
from authentik.tasks.schedules.models import Schedule
@receiver(post_save, sender=LifecycleRule)
@@ -12,9 +11,7 @@ def post_rule_save(sender, instance: LifecycleRule, created: bool, **_):
apply_lifecycle_rule.send_with_options(
args=(instance.id,),
rel_obj=Schedule.objects.get(
actor_name="authentik.enterprise.lifecycle.tasks.apply_lifecycle_rules"
),
rel_obj=instance,
)

View File

@@ -4,17 +4,14 @@ from dramatiq import actor
from authentik.core.models import User
from authentik.enterprise.lifecycle.models import LifecycleRule
from authentik.events.models import Event, Notification, NotificationTransport
from authentik.tasks.schedules.models import Schedule
@actor(description=_("Dispatch tasks to apply lifecycle rules."))
@actor(description=_("Dispatch tasks to validate lifecycle rules."))
def apply_lifecycle_rules():
for rule in LifecycleRule.objects.all():
apply_lifecycle_rule.send_with_options(
args=(rule.id,),
rel_obj=Schedule.objects.get(
actor_name="authentik.enterprise.lifecycle.tasks.apply_lifecycle_rules"
),
rel_obj=rule,
)

View File

@@ -1,4 +1,3 @@
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.urls import reverse
from rest_framework.test import APITestCase
@@ -20,11 +19,6 @@ class TestLifecycleRuleAPI(APITestCase):
self.content_type = ContentType.objects.get_for_model(Application)
self.reviewer_group = Group.objects.create(name=generate_id())
@classmethod
def setUpTestData(cls):
config = apps.get_app_config("authentik_tasks_schedules")
config._on_startup_callback(None)
def test_list_rules(self):
rule = LifecycleRule.objects.create(
name=generate_id(),
@@ -196,11 +190,6 @@ class TestIterationAPI(APITestCase):
self.reviewer_group = Group.objects.create(name=generate_id())
self.reviewer_group.users.add(self.user)
@classmethod
def setUpTestData(cls):
config = apps.get_app_config("authentik_tasks_schedules")
config._on_startup_callback(None)
def test_open_iterations(self):
rule = LifecycleRule.objects.create(
name=generate_id(),
@@ -242,7 +231,7 @@ class TestIterationAPI(APITestCase):
response = self.client.get(
reverse(
"authentik_api:lifecycleiteration-latest-iterations",
"authentik_api:lifecycleiteration-latest-iteration",
kwargs={
"content_type": f"{self.content_type.app_label}.{self.content_type.model}",
"object_id": str(self.app.pk),
@@ -250,20 +239,19 @@ class TestIterationAPI(APITestCase):
)
)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0]["object_id"], str(self.app.pk))
self.assertEqual(response.data["object_id"], str(self.app.pk))
def test_latest_iteration_not_found(self):
response = self.client.get(
reverse(
"authentik_api:lifecycleiteration-latest-iterations",
"authentik_api:lifecycleiteration-latest-iteration",
kwargs={
"content_type": f"{self.content_type.app_label}.{self.content_type.model}",
"object_id": "00000000-0000-0000-0000-000000000000",
},
)
)
self.assertEqual(response.data, [])
self.assertEqual(response.status_code, 404)
def test_iteration_includes_user_can_review(self):
rule = LifecycleRule.objects.create(
@@ -291,11 +279,6 @@ class TestReviewAPI(APITestCase):
self.reviewer_group = Group.objects.create(name=generate_id())
self.reviewer_group.users.add(self.user)
@classmethod
def setUpTestData(cls):
config = apps.get_app_config("authentik_tasks_schedules")
config._on_startup_callback(None)
def test_create_review(self):
rule = LifecycleRule.objects.create(
name=generate_id(),

View File

@@ -2,7 +2,6 @@ import datetime as dt
from datetime import timedelta
from unittest.mock import patch
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.test import RequestFactory, TestCase
from django.utils import timezone
@@ -30,11 +29,6 @@ class TestLifecycleModels(TestCase):
def setUp(self):
self.factory = RequestFactory()
@classmethod
def setUpTestData(cls):
config = apps.get_app_config("authentik_tasks_schedules")
config._on_startup_callback(None)
def _get_request(self):
return self.factory.get("/")
@@ -444,6 +438,31 @@ class TestLifecycleModels(TestCase):
self.assertIn(app_one, objects)
self.assertIn(app_two, objects)
def test_rule_type_excludes_objects_with_specific_rules(self):
app_with_rule = Application.objects.create(name=generate_id(), slug=generate_id())
app_without_rule = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(Application)
# Create a specific rule for app_with_rule
LifecycleRule.objects.create(
name=generate_id(),
content_type=content_type,
object_id=str(app_with_rule.pk),
interval="days=30",
)
# Create a type-level rule
type_rule = LifecycleRule.objects.create(
name=generate_id(),
content_type=content_type,
object_id=None,
interval="days=60",
)
objects = list(type_rule.get_objects())
self.assertNotIn(app_with_rule, objects)
self.assertIn(app_without_rule, objects)
def test_rule_type_apply_creates_iterations_for_all_objects(self):
app_one = Application.objects.create(name=generate_id(), slug=generate_id())
app_two = Application.objects.create(name=generate_id(), slug=generate_id())
@@ -650,73 +669,6 @@ class TestLifecycleModels(TestCase):
self.assertIn(explicit_reviewer, reviewers)
self.assertIn(group_member, reviewers)
def test_multiple_rules_same_object_create_separate_iterations(self):
"""Two rules targeting the same object each create their own iteration."""
obj = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(obj)
rule_one = self._create_rule_for_object(obj, interval="days=30", grace_period="days=10")
rule_two = self._create_rule_for_object(obj, interval="days=60", grace_period="days=20")
iterations = LifecycleIteration.objects.filter(
content_type=content_type, object_id=str(obj.pk)
)
self.assertEqual(iterations.count(), 2)
iter_one = iterations.get(rule=rule_one)
iter_two = iterations.get(rule=rule_two)
self.assertEqual(iter_one.state, ReviewState.PENDING)
self.assertEqual(iter_two.state, ReviewState.PENDING)
self.assertNotEqual(iter_one.pk, iter_two.pk)
def test_multiple_rules_same_object_reviewed_independently(self):
"""Reviewing one rule's iteration does not affect the other rule's iteration."""
obj = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(obj)
reviewer = create_test_user()
rule_one = self._create_rule_for_object(obj, min_reviewers=1)
rule_two = self._create_rule_for_object(obj, min_reviewers=1)
group = Group.objects.create(name=generate_id())
group.users.add(reviewer)
rule_one.reviewer_groups.add(group)
rule_two.reviewer_groups.add(group)
iter_one = LifecycleIteration.objects.get(
content_type=content_type, object_id=str(obj.pk), rule=rule_one
)
iter_two = LifecycleIteration.objects.get(
content_type=content_type, object_id=str(obj.pk), rule=rule_two
)
request = self._get_request()
# Review only rule_one's iteration
Review.objects.create(iteration=iter_one, reviewer=reviewer)
iter_one.on_review(request)
iter_one.refresh_from_db()
iter_two.refresh_from_db()
self.assertEqual(iter_one.state, ReviewState.REVIEWED)
self.assertEqual(iter_two.state, ReviewState.PENDING)
def test_type_rule_and_object_rule_both_create_iterations(self):
"""A type-level rule and an object-level rule both create iterations for the same object."""
obj = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(obj)
object_rule = self._create_rule_for_object(obj, interval="days=30")
type_rule = self._create_rule_for_type(Application, interval="days=60")
iterations = LifecycleIteration.objects.filter(
content_type=content_type, object_id=str(obj.pk)
)
self.assertEqual(iterations.count(), 2)
self.assertTrue(iterations.filter(rule=object_rule).exists())
self.assertTrue(iterations.filter(rule=type_rule).exists())
class TestLifecycleDateBoundaries(TestCase):
"""Verify that start_of_day normalization ensures correct overdue/due
@@ -727,11 +679,6 @@ class TestLifecycleDateBoundaries(TestCase):
ensures that the boundary is always at midnight, so millisecond variations
in task execution time do not affect results."""
@classmethod
def setUpTestData(cls):
config = apps.get_app_config("authentik_tasks_schedules")
config._on_startup_callback(None)
def _create_rule_and_iteration(self, grace_period="days=1", interval="days=365"):
app = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(Application)

View File

@@ -1,7 +1,6 @@
# Generated by Django 5.2.12 on 2026-04-04 16:58
from django.db import migrations, models
import django.contrib.postgres.fields
class Migration(migrations.Migration):
@@ -41,109 +40,4 @@ class Migration(migrations.Migration):
]
),
),
migrations.AlterField(
model_name="stream",
name="events_requested",
field=django.contrib.postgres.fields.ArrayField(
base_field=models.TextField(
choices=[
(
"https://schemas.openid.net/secevent/caep/event-type/session-revoked",
"Caep Session Revoked",
),
(
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change",
"Caep Token Claims Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/credential-change",
"Caep Credential Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change",
"Caep Assurance Level Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change",
"Caep Device Compliance Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/session-established",
"Caep Session Established",
),
(
"https://schemas.openid.net/secevent/caep/event-type/session-presented",
"Caep Session Presented",
),
(
"https://schemas.openid.net/secevent/caep/event-type/risk-level-change",
"Caep Risk Level Change",
),
(
"https://schemas.openid.net/secevent/ssf/event-type/verification",
"Set Verification",
),
]
),
default=list,
size=None,
),
),
migrations.AlterField(
model_name="stream",
name="status",
field=models.TextField(
choices=[
("enabled", "Enabled"),
("paused", "Paused"),
("disabled", "Disabled"),
("disabled_deleted", "Disabled Deleted"),
],
default="enabled",
),
),
migrations.AlterField(
model_name="streamevent",
name="type",
field=models.TextField(
choices=[
(
"https://schemas.openid.net/secevent/caep/event-type/session-revoked",
"Caep Session Revoked",
),
(
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change",
"Caep Token Claims Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/credential-change",
"Caep Credential Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change",
"Caep Assurance Level Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change",
"Caep Device Compliance Change",
),
(
"https://schemas.openid.net/secevent/caep/event-type/session-established",
"Caep Session Established",
),
(
"https://schemas.openid.net/secevent/caep/event-type/session-presented",
"Caep Session Presented",
),
(
"https://schemas.openid.net/secevent/caep/event-type/risk-level-change",
"Caep Risk Level Change",
),
(
"https://schemas.openid.net/secevent/ssf/event-type/verification",
"Set Verification",
),
]
),
),
]

View File

@@ -24,31 +24,8 @@ class EventTypes(models.TextChoices):
"""SSF Event types supported by authentik"""
CAEP_SESSION_REVOKED = "https://schemas.openid.net/secevent/caep/event-type/session-revoked"
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.1"""
CAEP_TOKEN_CLAIMS_CHANGE = (
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change"
)
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.2"""
CAEP_CREDENTIAL_CHANGE = "https://schemas.openid.net/secevent/caep/event-type/credential-change"
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.3"""
CAEP_ASSURANCE_LEVEL_CHANGE = (
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change"
)
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.4"""
CAEP_DEVICE_COMPLIANCE_CHANGE = (
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change"
)
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.5"""
CAEP_SESSION_ESTABLISHED = (
"https://schemas.openid.net/secevent/caep/event-type/session-established"
)
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.6"""
CAEP_SESSION_PRESENTED = "https://schemas.openid.net/secevent/caep/event-type/session-presented"
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.7"""
CAEP_RISK_LEVEL_CHANGE = "https://schemas.openid.net/secevent/caep/event-type/risk-level-change"
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.8"""
SET_VERIFICATION = "https://schemas.openid.net/secevent/ssf/event-type/verification"
"""https://openid.net/specs/openid-sharedsignals-framework-1_0.html#section-8.1.4.1"""
class DeliveryMethods(models.TextChoices):
@@ -69,12 +46,10 @@ class SSFEventStatus(models.TextChoices):
class StreamStatus(models.TextChoices):
"""SSF Stream status"""
ENABLED = "enabled"
PAUSED = "paused"
DISABLED = "disabled"
DISABLED_DELETED = "disabled_deleted"
class SSFProvider(TasksModel, BackchannelProvider):

View File

@@ -12,7 +12,7 @@ from authentik.core.models import (
User,
UserTypes,
)
from authentik.core.signals import password_changed, password_hash_changed
from authentik.core.signals import password_changed
from authentik.enterprise.providers.ssf.models import (
EventTypes,
SSFProvider,
@@ -84,13 +84,14 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
)
def _send_password_credential_change(user: User, change_type: str):
@receiver(password_changed)
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
"""Credential change trigger (password changed)"""
send_ssf_events(
EventTypes.CAEP_CREDENTIAL_CHANGE,
{
"credential_type": "password",
"change_type": change_type,
"change_type": "revoke" if password is None else "update",
},
sub_id={
"format": "complex",
@@ -102,16 +103,6 @@ def _send_password_credential_change(user: User, change_type: str):
)
@receiver(password_hash_changed)
@receiver(password_changed)
def ssf_password_changed_cred_change(signal, sender, user: User, password: str | None = None, **_):
"""Credential change trigger (password changed)"""
if signal is password_hash_changed:
_send_password_credential_change(user, "update")
return
_send_password_credential_change(user, "revoke" if password is None else "update")
device_type_map = {
StaticDevice: "pin",
TOTPDevice: "pin",

View File

@@ -108,13 +108,13 @@ def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
event.save()
self.info("Event successfully sent", status=response.status_code)
# Cleanup, if we were the last pending message for this stream and it has been deleted
# (status=StreamStatus.DISABLED_DELETED), then we can delete the stream
# (status=StreamStatus.DISABLED), then we can delete the stream
if (
not StreamEvent.objects.filter(
stream=stream,
status__in=[SSFEventStatus.PENDING_FAILED, SSFEventStatus.PENDING_NEW],
).exists()
and stream.status == StreamStatus.DISABLED_DELETED
and stream.status == StreamStatus.DISABLED
):
LOGGER.info(
"Deleting inactive stream as all pending messages were sent.", stream=stream

View File

@@ -62,7 +62,7 @@ class TestSSFAuth(APITestCase):
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
self.assertEqual(
event.payload["events"],
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
)
def test_stream_add_oidc(self):
@@ -115,7 +115,7 @@ class TestSSFAuth(APITestCase):
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
self.assertEqual(
event.payload["events"],
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
)
def test_token_invalid(self):

View File

@@ -1,6 +1,5 @@
from uuid import uuid4
from django.contrib.auth.hashers import make_password
from django.urls import reverse
from rest_framework.test import APITestCase
@@ -53,21 +52,6 @@ class TestSignals(APITestCase):
)
self.assertEqual(res.status_code, 201, res.content)
def _assert_password_credential_change(self, user, change_type: str):
stream = Stream.objects.filter(provider=self.provider).first()
self.assertIsNotNone(stream)
event = StreamEvent.objects.filter(stream=stream).first()
self.assertIsNotNone(event)
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
event_payload = event.payload["events"][
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
]
self.assertEqual(event_payload["change_type"], change_type)
self.assertEqual(event_payload["credential_type"], "password")
self.assertEqual(event.payload["sub_id"]["format"], "complex")
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
def test_signal_logout(self):
"""Test user logout"""
user = create_test_user()
@@ -95,25 +79,19 @@ class TestSignals(APITestCase):
user.set_password(generate_id())
user.save()
self._assert_password_credential_change(user, "update")
def test_signal_password_change_from_hash(self):
"""Test user password change from a pre-hashed password."""
user = create_test_user()
self.client.force_login(user)
user.set_password_from_hash(make_password(generate_id()))
user.save()
self._assert_password_credential_change(user, "update")
def test_signal_password_revoke(self):
"""Test explicit password revoke."""
user = create_test_user()
self.client.force_login(user)
user.set_password(None)
user.save()
self._assert_password_credential_change(user, "revoke")
stream = Stream.objects.filter(provider=self.provider).first()
self.assertIsNotNone(stream)
event = StreamEvent.objects.filter(stream=stream).first()
self.assertIsNotNone(event)
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
event_payload = event.payload["events"][
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
]
self.assertEqual(event_payload["change_type"], "update")
self.assertEqual(event_payload["credential_type"], "password")
self.assertEqual(event.payload["sub_id"]["format"], "complex")
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
def test_signal_authenticator_added(self):
"""Test authenticator creation signal"""

View File

@@ -54,7 +54,7 @@ class TestStream(APITestCase):
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
self.assertEqual(
event.payload["events"],
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
)
def test_stream_add_poll(self):
@@ -96,7 +96,7 @@ class TestStream(APITestCase):
)
self.assertEqual(res.status_code, 204)
stream.refresh_from_db()
self.assertEqual(stream.status, StreamStatus.DISABLED_DELETED)
self.assertEqual(stream.status, StreamStatus.DISABLED)
def test_stream_get(self):
"""get stream"""
@@ -225,26 +225,3 @@ class TestStream(APITestCase):
HTTP_AUTHORIZATION=f"Bearer {self.provider.token.key}",
)
self.assertEqual(res.status_code, 404)
def test_stream_status_update(self):
stream = Stream.objects.create(provider=self.provider)
res = self.client.post(
reverse(
"authentik_providers_ssf:stream-status",
kwargs={"application_slug": self.application.slug},
),
data={
"stream_id": str(stream.pk),
"status": StreamStatus.DISABLED,
},
HTTP_AUTHORIZATION=f"Bearer {self.provider.token.key}",
)
self.assertEqual(res.status_code, 200)
stream.refresh_from_db()
self.assertJSONEqual(
res.content,
{
"stream_id": str(stream.pk),
"status": str(stream.status),
},
)

View File

@@ -33,7 +33,7 @@ class TestTasks(APITestCase):
)
event_data = stream.prepare_event_payload(
EventTypes.SET_VERIFICATION,
{},
{"state": None},
sub_id={"format": "opaque", "id": str(stream.uuid)},
)
with Mocker() as mocker:
@@ -46,7 +46,7 @@ class TestTasks(APITestCase):
)
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
def test_push_auth(self):
auth = generate_id()
@@ -58,7 +58,7 @@ class TestTasks(APITestCase):
)
event_data = stream.prepare_event_payload(
EventTypes.SET_VERIFICATION,
{},
{"state": None},
sub_id={"format": "opaque", "id": str(stream.uuid)},
)
with Mocker() as mocker:
@@ -72,7 +72,7 @@ class TestTasks(APITestCase):
)
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
def test_push_stream_disable(self):
auth = generate_id()
@@ -81,11 +81,11 @@ class TestTasks(APITestCase):
delivery_method=DeliveryMethods.RFC_PUSH,
endpoint_url="http://localhost/ssf-push",
authorization_header=auth,
status=StreamStatus.DISABLED_DELETED,
status=StreamStatus.DISABLED,
)
event_data = stream.prepare_event_payload(
EventTypes.SET_VERIFICATION,
{},
{"state": None},
sub_id={"format": "opaque", "id": str(stream.uuid)},
)
with Mocker() as mocker:
@@ -95,7 +95,7 @@ class TestTasks(APITestCase):
).get_result(block=True, timeout=1)
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
self.assertFalse(Stream.objects.filter(pk=stream.pk).exists())
def test_push_error(self):
@@ -106,7 +106,7 @@ class TestTasks(APITestCase):
)
event_data = stream.prepare_event_payload(
EventTypes.SET_VERIFICATION,
{},
{"state": None},
sub_id={"format": "opaque", "id": str(stream.uuid)},
)
with Mocker() as mocker:

View File

@@ -24,10 +24,10 @@ class SSFView(APIView):
class SSFStreamView(SSFView):
def get_object(self) -> Stream:
streams = Stream.objects.filter(provider=self.provider).exclude(
status=StreamStatus.DISABLED_DELETED
)
def get_object(self, any_status=False) -> Stream:
streams = Stream.objects.filter(provider=self.provider)
if not any_status:
streams = streams.filter(status__in=[StreamStatus.ENABLED, StreamStatus.PAUSED])
if "stream_id" in self.request.query_params:
streams = streams.filter(pk=self.request.query_params["stream_id"])
if "stream_id" in self.request.data:

View File

@@ -1,6 +1,6 @@
from uuid import uuid4
from django.http import Http404, HttpRequest
from django.http import HttpRequest
from django.urls import reverse
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
@@ -106,11 +106,7 @@ class StreamResponseSerializer(PassiveSerializer):
}
def get_events_supported(self, instance: Stream) -> list[str]:
return [
EventTypes.CAEP_SESSION_REVOKED,
EventTypes.CAEP_CREDENTIAL_CHANGE,
EventTypes.SET_VERIFICATION,
]
return [x.value for x in EventTypes]
class StreamView(SSFStreamView):
@@ -132,9 +128,10 @@ class StreamView(SSFStreamView):
LOGGER.info("Sending verification event", stream=instance)
send_ssf_events(
EventTypes.SET_VERIFICATION,
{},
{
"state": None,
},
stream_filter={"pk": instance.uuid},
request=request,
sub_id={"format": "opaque", "id": str(instance.uuid)},
)
response = StreamResponseSerializer(instance=instance, context={"request": request}).data
@@ -162,9 +159,7 @@ class StreamView(SSFStreamView):
def delete(self, request: Request, *args, **kwargs) -> Response:
stream = self.get_object()
if stream.status == StreamStatus.DISABLED_DELETED:
raise Http404
stream.status = StreamStatus.DISABLED_DELETED
stream.status = StreamStatus.DISABLED
stream.save()
return Response(status=204)
@@ -180,7 +175,6 @@ class StreamVerifyView(SSFStreamView):
"state": state,
},
stream_filter={"pk": stream.uuid},
request=request,
sub_id={"format": "opaque", "id": str(stream.uuid)},
)
return Response(status=204)
@@ -188,25 +182,8 @@ class StreamVerifyView(SSFStreamView):
class StreamStatusView(SSFStreamView):
class StreamStatusSerializer(PassiveSerializer):
stream_id = CharField()
status = ChoiceField(choices=StreamStatus.choices)
def get(self, request: Request, *args, **kwargs):
stream = self.get_object()
return Response(
{
"stream_id": str(stream.pk),
"status": str(stream.status),
}
)
def post(self, request: Request, *args, **kwargs):
stream = self.get_object()
serializer = self.StreamStatusSerializer(stream, data=request.data)
serializer.is_valid(raise_exception=True)
stream.status = serializer.validated_data["status"]
stream.save()
stream = self.get_object(any_status=True)
return Response(
{
"stream_id": str(stream.pk),

View File

@@ -14,7 +14,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.ssf",
"authentik.enterprise.providers.ws_federation",
"authentik.enterprise.reports",
"authentik.enterprise.stages.account_lockdown",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source",

View File

@@ -1,141 +0,0 @@
"""Account Lockdown Stage API Views"""
from django.utils.translation import gettext as _
from drf_spectacular.utils import OpenApiExample, OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import PrimaryKeyRelatedField
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.validation import validate
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import LinkSerializer, PassiveSerializer
from authentik.core.models import (
User,
)
from authentik.enterprise.api import EnterpriseRequiredMixin, enterprise_action
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
from authentik.enterprise.stages.account_lockdown.stage import (
can_lock_user,
get_lockdown_target_users,
)
from authentik.flows.api.stages import StageSerializer
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
LOGGER = get_logger()
class AccountLockdownStageSerializer(EnterpriseRequiredMixin, StageSerializer):
"""AccountLockdownStage Serializer"""
class Meta:
model = AccountLockdownStage
fields = StageSerializer.Meta.fields + [
"deactivate_user",
"set_unusable_password",
"delete_sessions",
"revoke_tokens",
"self_service_completion_flow",
]
class AccountLockdownStageViewSet(UsedByMixin, ModelViewSet):
"""AccountLockdownStage Viewset"""
queryset = AccountLockdownStage.objects.all()
serializer_class = AccountLockdownStageSerializer
filterset_fields = "__all__"
ordering = ["name"]
search_fields = ["name"]
class UserAccountLockdownSerializer(PassiveSerializer):
"""Choose the target account before starting the lockdown flow."""
user = PrimaryKeyRelatedField(
queryset=get_lockdown_target_users(),
required=False,
allow_null=True,
help_text=_("User to lock. If omitted, locks the current user (self-service)."),
)
class UserAccountLockdownMixin:
"""Enterprise account-lockdown API actions for UserViewSet."""
def _create_lockdown_flow_url(self, request: Request, user: User) -> str:
"""Create a flow URL for account lockdown.
The request body selects the target before the flow starts. The API
pre-plans the lockdown flow with the target as the pending user, so the
account lockdown stage can use the normal flow context.
"""
flow = request._request.brand.flow_lockdown
if flow is None:
raise ValidationError({"non_field_errors": [_("No lockdown flow configured.")]})
planner = FlowPlanner(flow)
planner.use_cache = False
try:
plan = planner.plan(request._request, {PLAN_CONTEXT_PENDING_USER: user})
except EmptyFlowException, FlowNonApplicableException:
raise ValidationError(
{"non_field_errors": [_("Lockdown flow is not applicable.")]}
) from None
return plan.to_redirect(request._request, flow).url
@extend_schema(
description=_("Choose the target account, then return a flow link."),
request=UserAccountLockdownSerializer,
responses={
"200": OpenApiResponse(
response=LinkSerializer,
examples=[
OpenApiExample(
"Lockdown flow URL",
value={
"link": "https://example.invalid/if/flow/default-account-lockdown/",
},
response_only=True,
status_codes=["200"],
)
],
),
"400": OpenApiResponse(
description=_("No lockdown flow configured or the flow is not applicable")
),
"403": OpenApiResponse(
description=_("Permission denied (when targeting another user)")
),
},
)
@action(
detail=False,
methods=["POST"],
permission_classes=[IsAuthenticated],
url_path="account_lockdown",
)
@validate(UserAccountLockdownSerializer)
@enterprise_action
def account_lockdown(self, request: Request, body: UserAccountLockdownSerializer) -> Response:
"""Trigger account lockdown for a user.
If no user is specified, locks the current user (self-service).
When targeting another user, admin permissions are required.
Returns a flow link for the frontend to follow. The flow is pre-planned
with the target user as pending user for the lockdown stage.
"""
user = body.validated_data.get("user") or request.user
if not can_lock_user(request.user, user):
LOGGER.debug("Permission denied for account lockdown", user=request.user)
self.permission_denied(request)
flow_url = self._create_lockdown_flow_url(request, user)
LOGGER.debug("Returning lockdown flow URL", flow_url=flow_url, user=user.username)
return Response({"link": flow_url})

View File

@@ -1,12 +0,0 @@
"""authentik account lockdown stage app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseStageAccountLockdownConfig(EnterpriseConfig):
"""authentik account lockdown stage config"""
name = "authentik.enterprise.stages.account_lockdown"
label = "authentik_stages_account_lockdown"
verbose_name = "authentik Enterprise.Stages.Account Lockdown"
default = True

View File

@@ -1,74 +0,0 @@
# Generated by Django 5.2.13 on 2026-04-19 21:56
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_flows", "0031_alter_flow_layout"),
]
operations = [
migrations.CreateModel(
name="AccountLockdownStage",
fields=[
(
"stage_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_flows.stage",
),
),
(
"deactivate_user",
models.BooleanField(
default=True,
help_text="Deactivate the user account (set is_active to False)",
),
),
(
"set_unusable_password",
models.BooleanField(
default=True, help_text="Set an unusable password for the user"
),
),
(
"delete_sessions",
models.BooleanField(
default=True, help_text="Delete all active sessions for the user"
),
),
(
"revoke_tokens",
models.BooleanField(
default=True,
help_text="Revoke all tokens for the user (API, app password, recovery, verification, OAuth)",
),
),
(
"self_service_completion_flow",
models.ForeignKey(
blank=True,
help_text="Flow to redirect users to after self-service lockdown. This flow should not require authentication since the user's session is deleted.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="account_lockdown_stages",
to="authentik_flows.flow",
),
),
],
options={
"verbose_name": "Account Lockdown Stage",
"verbose_name_plural": "Account Lockdown Stages",
},
bases=("authentik_flows.stage",),
),
]

View File

@@ -1,62 +0,0 @@
"""Account lockdown stage models"""
from django.db import models
from django.utils.translation import gettext_lazy as _
from django.views import View
from rest_framework.serializers import BaseSerializer
from authentik.flows.models import Stage
class AccountLockdownStage(Stage):
"""Lock down a target user account."""
deactivate_user = models.BooleanField(
default=True,
help_text=_("Deactivate the user account (set is_active to False)"),
)
set_unusable_password = models.BooleanField(
default=True,
help_text=_("Set an unusable password for the user"),
)
delete_sessions = models.BooleanField(
default=True,
help_text=_("Delete all active sessions for the user"),
)
revoke_tokens = models.BooleanField(
default=True,
help_text=_(
"Revoke all tokens for the user (API, app password, recovery, verification, OAuth)"
),
)
self_service_completion_flow = models.ForeignKey(
"authentik_flows.Flow",
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="account_lockdown_stages",
help_text=_(
"Flow to redirect users to after self-service lockdown. "
"This flow should not require authentication since the user's session is deleted."
),
)
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.enterprise.stages.account_lockdown.api import AccountLockdownStageSerializer
return AccountLockdownStageSerializer
@property
def view(self) -> type[View]:
from authentik.enterprise.stages.account_lockdown.stage import AccountLockdownStageView
return AccountLockdownStageView
@property
def component(self) -> str:
return "ak-stage-account-lockdown-form"
class Meta:
verbose_name = _("Account Lockdown Stage")
verbose_name_plural = _("Account Lockdown Stages")

View File

@@ -1,345 +0,0 @@
"""Account lockdown stage logic"""
from django.apps import apps
from django.core.exceptions import FieldDoesNotExist
from django.db.models import Model, QuerySet
from django.db.models.query_utils import Q
from django.db.transaction import atomic
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.composition import group
from dramatiq.results.errors import ResultTimeout
from authentik.core.models import (
AuthenticatedSession,
ExpiringModel,
Session,
Token,
User,
UserTypes,
)
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
from authentik.events.models import Event, EventAction
from authentik.flows.stage import StageView
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.sync.outgoing.signals import sync_outgoing_inhibit_dispatch
from authentik.lib.utils.reflection import class_to_path
from authentik.lib.utils.time import timedelta_from_string
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
PLAN_CONTEXT_LOCKDOWN_REASON = "lockdown_reason"
LOCKDOWN_EVENT_ACTION_ID = "account_lockdown"
TARGET_REQUIRED_MESSAGE = _("No target user specified for account lockdown")
PERMISSION_DENIED_MESSAGE = _("You do not have permission to lock down this account.")
ACCOUNT_LOCKDOWN_FAILED_MESSAGE = _("Account lockdown failed for this account.")
SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE = _(
"Self-service account lockdown requires a completion flow."
)
def get_lockdown_target_users() -> QuerySet[User]:
"""Return users that can be targeted by account lockdown."""
return User.objects.exclude_anonymous().exclude(type=UserTypes.INTERNAL_SERVICE_ACCOUNT)
def _get_model_field(model: type[Model], field_name: str):
"""Get a model field by name, if present."""
try:
return model._meta.get_field(field_name)
except FieldDoesNotExist:
return None
def _has_user_field(model: type[Model]) -> bool:
"""Check if a model has a direct user foreign key."""
field = _get_model_field(model, "user")
return bool(field and getattr(field, "remote_field", None) and field.remote_field.model is User)
def _has_authenticated_session_field(model: type[Model]) -> bool:
"""Check if a model is linked to an authenticated session."""
field = _get_model_field(model, "session")
return bool(
field
and getattr(field, "remote_field", None)
and field.remote_field.model is AuthenticatedSession
)
def _has_provider_field(model: type[Model]) -> bool:
"""Check if a model is linked to a provider."""
return _get_model_field(model, "provider") is not None
def get_lockdown_token_models() -> tuple[type[Model], ...]:
"""Return token, grant, and provider session models removed by account lockdown."""
token_models: list[type[Model]] = []
for model in apps.get_models():
if model._meta.abstract or not issubclass(model, ExpiringModel):
continue
if model is Token:
token_models.append(model)
elif _has_user_field(model) and (
_has_provider_field(model) or _has_authenticated_session_field(model)
):
token_models.append(model)
elif _has_authenticated_session_field(model):
token_models.append(model)
return tuple(token_models)
def get_lockdown_token_queryset(model: type[Model], user: User) -> QuerySet:
"""Return account lockdown artifacts for a model and user."""
manager = model.objects.including_expired()
if _has_user_field(model):
return manager.filter(user=user)
return manager.filter(session__user=user)
def can_lock_user(actor, user: User) -> bool:
"""Check whether the actor may lock the target user."""
if not actor.is_authenticated:
return False
if user.pk == actor.pk:
return True
return actor.has_perm("authentik_core.change_user", user)
def get_outgoing_sync_tasks() -> tuple[tuple[type[OutgoingSyncProvider], Actor], ...]:
"""Return outgoing sync provider types and their direct sync tasks."""
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync_direct
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync_direct
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync_direct
return (
(SCIMProvider, scim_sync_direct),
(GoogleWorkspaceProvider, google_workspace_sync_direct),
(MicrosoftEntraProvider, microsoft_entra_sync_direct),
)
class AccountLockdownStageView(StageView):
"""Execute account lockdown actions on the target user."""
def is_self_service(self, request: HttpRequest, user: User) -> bool:
"""Check whether the currently authenticated user is locking their own account."""
return request.user.is_authenticated and user.pk == request.user.pk
def get_reason(self) -> str:
"""Get the lockdown reason from the plan context.
Priority:
1. prompt_data[PLAN_CONTEXT_LOCKDOWN_REASON]
2. PLAN_CONTEXT_LOCKDOWN_REASON (explicitly set)
3. Empty string as fallback
"""
prompt_data = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
if PLAN_CONTEXT_LOCKDOWN_REASON in prompt_data:
return prompt_data[PLAN_CONTEXT_LOCKDOWN_REASON]
return self.executor.plan.context.get(PLAN_CONTEXT_LOCKDOWN_REASON, "")
def _apply_lockdown_actions(self, stage: AccountLockdownStage, user: User) -> None:
"""Apply the configured account changes to the target user."""
if stage.deactivate_user:
user.is_active = False
if stage.set_unusable_password:
user.set_unusable_password()
if stage.deactivate_user:
with sync_outgoing_inhibit_dispatch():
user.save()
return
user.save()
def _sync_deactivated_user_to_outgoing_providers(self, user: User) -> None:
"""Synchronize a deactivated user to outgoing sync providers."""
messages = []
wait_timeout = 0
model = class_to_path(User)
provider_filter = Q(backchannel_application__isnull=False) | Q(application__isnull=False)
for provider_model, task_sync_direct in get_outgoing_sync_tasks():
for provider in provider_model.objects.filter(provider_filter):
time_limit = int(
timedelta_from_string(provider.sync_page_timeout).total_seconds() * 1000
)
messages.append(
task_sync_direct.message_with_options(
args=(model, user.pk, provider.pk),
rel_obj=provider,
time_limit=time_limit,
uid=f"{provider.name}:user:{user.pk}:direct",
)
)
wait_timeout += time_limit
if not messages:
return
try:
group(messages).run().wait(timeout=wait_timeout)
except ResultTimeout:
self.logger.warning(
"Timed out waiting for outgoing sync tasks; tasks remain queued",
user=user.username,
timeout=wait_timeout,
)
def _get_lockdown_artifact_querysets(
self, stage: AccountLockdownStage, user: User
) -> tuple[QuerySet, ...]:
"""Return the configured sessions and tokens targeted by lockdown."""
querysets: list[QuerySet] = []
if stage.delete_sessions:
querysets.append(Session.objects.filter(authenticatedsession__user=user))
if stage.revoke_tokens:
querysets.extend(
get_lockdown_token_queryset(model, user) for model in get_lockdown_token_models()
)
return tuple(querysets)
def _delete_lockdown_artifacts(self, stage: AccountLockdownStage, user: User) -> None:
"""Delete sessions and tokens selected by the lockdown configuration."""
for queryset in self._get_lockdown_artifact_querysets(stage, user):
queryset.delete()
def _has_lockdown_artifacts(self, stage: AccountLockdownStage, user: User) -> bool:
"""Check whether there are still sessions or tokens to remove."""
return any(
queryset.exists() for queryset in self._get_lockdown_artifact_querysets(stage, user)
)
def _emit_lockdown_event(self, request: HttpRequest, user: User, reason: str) -> None:
"""Emit the audit event for a completed lockdown."""
# Emit the audit event after the transaction commits. If event creation
# fails here, dispatch() would otherwise treat the whole lockdown as
# failed even though the account changes have already been committed.
try:
Event.new(
EventAction.USER_WRITE,
action_id=LOCKDOWN_EVENT_ACTION_ID,
reason=reason,
affected_user=user.username,
).from_http(request)
except Exception as exc: # noqa: BLE001
# Event emission should not make the lockdown itself fail.
self.logger.warning(
"Failed to emit account lockdown event",
user=user.username,
exc=exc,
)
def _lockdown_user(
self,
request: HttpRequest,
stage: AccountLockdownStage,
user: User,
reason: str,
) -> None:
"""Execute lockdown actions on a single user."""
with atomic():
user = User.objects.get(pk=user.pk)
self._apply_lockdown_actions(stage, user)
self._delete_lockdown_artifacts(stage, user)
# These additional checks/deletes are done to prevent a timing attack that creates tokens
# with a compromised token that is simultaneously being deleted.
while self._has_lockdown_artifacts(stage, user):
with atomic():
self._delete_lockdown_artifacts(stage, user)
if stage.deactivate_user:
try:
self._sync_deactivated_user_to_outgoing_providers(user)
except Exception as exc: # noqa: BLE001
# Local lockdown has already committed. Provider sync failures
# must not reopen access or mark the lockdown itself as failed.
self.logger.warning(
"Failed to sync account lockdown deactivation to outgoing providers",
user=user.username,
exc=exc,
)
self._emit_lockdown_event(request, user, reason)
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Execute account lockdown actions."""
self.request = request
stage: AccountLockdownStage = self.executor.current_stage
pending_user = self.get_pending_user()
if not pending_user.is_authenticated:
self.logger.warning("No target user found for account lockdown")
return self.executor.stage_invalid(TARGET_REQUIRED_MESSAGE)
user = get_lockdown_target_users().filter(pk=pending_user.pk).first()
if user is None:
self.logger.warning("Target user is not eligible for account lockdown")
return self.executor.stage_invalid(TARGET_REQUIRED_MESSAGE)
if not can_lock_user(request.user, user):
self.logger.warning(
"Permission denied for account lockdown",
actor=getattr(request.user, "username", None),
target=user.username,
)
return self.executor.stage_invalid(PERMISSION_DENIED_MESSAGE)
reason = self.get_reason()
self_service = self.is_self_service(request, user)
if self_service and stage.delete_sessions and not stage.self_service_completion_flow:
self.logger.warning("No completion flow configured for self-service account lockdown")
return self.executor.stage_invalid(SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE)
self.logger.info(
"Executing account lockdown",
user=user.username,
reason=reason,
self_service=self_service,
deactivate_user=stage.deactivate_user,
set_unusable_password=stage.set_unusable_password,
delete_sessions=stage.delete_sessions,
revoke_tokens=stage.revoke_tokens,
)
try:
self._lockdown_user(request, stage, user, reason)
self.logger.info("Account lockdown completed", user=user.username)
except Exception as exc: # noqa: BLE001
# Convert unexpected lockdown errors to a flow-stage failure instead
# of leaking an exception through the flow executor.
self.logger.warning("Account lockdown failed", user=user.username, exc=exc)
return self.executor.stage_invalid(ACCOUNT_LOCKDOWN_FAILED_MESSAGE)
if self_service:
if stage.delete_sessions:
return self._self_service_completion_response(request)
return self.executor.stage_ok()
return self.executor.stage_ok()
def _self_service_completion_response(self, request: HttpRequest) -> HttpResponse:
"""Redirect to completion flow after self-service lockdown.
Since all sessions are deleted, the user cannot continue in the flow.
Redirect them to an unauthenticated completion flow that shows the
lockdown message.
We use a direct HTTP redirect instead of a challenge because the
flow executor's challenge handling may try to access the session
which we just deleted.
"""
stage: AccountLockdownStage = self.executor.current_stage
completion_flow = stage.self_service_completion_flow
if completion_flow:
# Flush the current request's session to prevent Django's session
# middleware from trying to save a deleted session
if hasattr(request, "session"):
request.session.flush()
redirect_to = reverse(
"authentik_core:if-flow",
kwargs={"flow_slug": completion_flow.slug},
)
return HttpResponseRedirect(redirect_to)
return self.executor.stage_invalid(SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE)

View File

@@ -1,148 +0,0 @@
"""Test Users Account Lockdown API"""
from json import loads
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import (
create_test_brand,
create_test_flow,
create_test_user,
)
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
# Patch for enterprise license check
patch_license = patch(
"authentik.enterprise.models.LicenseUsageStatus.is_valid",
MagicMock(return_value=True),
)
@patch_license
class AccountLockdownAPITestCase(APITestCase):
"""Shared helpers for account lockdown API tests."""
def setUp(self) -> None:
self.lockdown_flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
self.lockdown_stage = AccountLockdownStage.objects.create(name=generate_id())
FlowStageBinding.objects.create(
target=self.lockdown_flow,
stage=self.lockdown_stage,
order=0,
)
self.brand = create_test_brand()
self.brand.flow_lockdown = self.lockdown_flow
self.brand.save()
def create_user_with_email(self):
"""Create a regular user with a unique email address."""
user = create_test_user()
user.email = f"{generate_id()}@test.com"
user.save()
return user
def assert_redirect_targets(self, response, user):
"""Assert that a response contains a pre-planned lockdown flow link for a user."""
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertIn(self.lockdown_flow.slug, body["link"])
self.assertEqual(urlparse(body["link"]).query, "")
plan = self.client.session[SESSION_KEY_PLAN]
self.assertEqual(plan.context[PLAN_CONTEXT_PENDING_USER].pk, user.pk)
def assert_no_flow_configured(self, response):
"""Assert that the API reports a missing lockdown flow."""
self.assertEqual(response.status_code, 400)
body = loads(response.content)
self.assertIn("No lockdown flow configured", body["non_field_errors"][0])
@patch_license
class TestUsersAccountLockdownAPI(AccountLockdownAPITestCase):
"""Test Users Account Lockdown API"""
def setUp(self) -> None:
super().setUp()
self.actor = create_test_user()
self.user = self.create_user_with_email()
def test_account_lockdown_with_change_user_returns_redirect(self):
"""Test that account lockdown allows users with change_user permission."""
self.actor.assign_perms_to_managed_role("authentik_core.change_user", self.user)
self.client.force_login(self.actor)
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={"user": self.user.pk},
format="json",
)
self.assert_redirect_targets(response, self.user)
def test_account_lockdown_no_flow_configured(self):
"""Test account lockdown when no flow is configured"""
self.brand.flow_lockdown = None
self.brand.save()
self.actor.assign_perms_to_managed_role("authentik_core.change_user", self.user)
self.client.force_login(self.actor)
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={"user": self.user.pk},
format="json",
)
self.assert_no_flow_configured(response)
def test_account_lockdown_unauthenticated(self):
"""Test account lockdown requires authentication"""
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={"user": self.user.pk},
format="json",
)
self.assertEqual(response.status_code, 403)
def test_account_lockdown_without_change_user_denied(self):
"""Test account lockdown denies users without change_user permission."""
self.client.force_login(self.actor)
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={"user": self.user.pk},
format="json",
)
self.assertEqual(response.status_code, 403)
def test_account_lockdown_self_returns_redirect(self):
"""Test successful self-service account lockdown returns a direct redirect."""
self.client.force_login(self.user)
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={},
format="json",
)
self.assert_redirect_targets(response, self.user)
def test_account_lockdown_self_target_without_change_user_returns_redirect(self):
"""Test self-service does not require change_user permission."""
self.client.force_login(self.user)
response = self.client.post(
reverse("authentik_api:user-account-lockdown"),
data={"user": self.user.pk},
format="json",
)
self.assert_redirect_targets(response, self.user)

View File

@@ -1,46 +0,0 @@
"""Tests for the packaged account-lockdown blueprint."""
from unittest.mock import patch
from django.test import TransactionTestCase
from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.tasks import blueprints_find, check_blueprint_v1_file
from authentik.enterprise.license import LicenseKey
from authentik.flows.models import Flow
BLUEPRINT_PATH = "example/flow-default-account-lockdown.yaml"
class TestAccountLockdownBlueprint(TransactionTestCase):
"""Test the packaged account-lockdown blueprint behavior."""
def test_blueprint_is_not_auto_instantiated(self):
"""Test the packaged blueprint is opt-in and skipped by discovery."""
BlueprintInstance.objects.filter(path=BLUEPRINT_PATH).delete()
blueprint = next(item for item in blueprints_find() if item.path == BLUEPRINT_PATH)
check_blueprint_v1_file(blueprint)
self.assertFalse(BlueprintInstance.objects.filter(path=BLUEPRINT_PATH).exists())
def test_blueprint_requires_licensed_context(self):
"""Test manual import only creates flows when enterprise is licensed."""
content = BlueprintInstance(path=BLUEPRINT_PATH).retrieve()
license_key = LicenseKey("test", 253402300799, "Test license", 1000, 1000)
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
importer = Importer.from_string(content, {"goauthentik.io/enterprise/licensed": False})
valid, logs = importer.validate()
self.assertTrue(valid, logs)
self.assertTrue(importer.apply())
self.assertFalse(Flow.objects.filter(slug="default-account-lockdown").exists())
self.assertFalse(Flow.objects.filter(slug="default-account-lockdown-complete").exists())
importer = Importer.from_string(content, {"goauthentik.io/enterprise/licensed": True})
valid, logs = importer.validate()
self.assertTrue(valid, logs)
self.assertTrue(importer.apply())
self.assertTrue(Flow.objects.filter(slug="default-account-lockdown").exists())
self.assertTrue(Flow.objects.filter(slug="default-account-lockdown-complete").exists())

View File

@@ -1,627 +0,0 @@
"""Account lockdown stage tests"""
import json
from dataclasses import asdict
from threading import Event as ThreadEvent
from threading import Thread
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from django.db import connection
from django.http import HttpResponse
from django.test import TransactionTestCase
from django.urls import reverse
from django.utils import timezone
from dramatiq.results.errors import ResultTimeout
from authentik.core.models import AuthenticatedSession, Session, Token, TokenIntents
from authentik.core.tests.utils import (
RequestFactory,
create_test_admin_user,
create_test_cert,
create_test_flow,
create_test_user,
)
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
from authentik.enterprise.stages.account_lockdown.stage import (
LOCKDOWN_EVENT_ACTION_ID,
PLAN_CONTEXT_LOCKDOWN_REASON,
AccountLockdownStageView,
can_lock_user,
)
from authentik.events.models import Event, EventAction
from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.lib.utils.reflection import class_to_path
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
DeviceToken,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
RefreshToken,
)
from authentik.providers.saml.models import SAMLProvider, SAMLSession
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
patch_enterprise_enabled = patch(
"authentik.enterprise.apps.AuthentikEnterpriseConfig.check_enabled",
return_value=True,
)
class AccountLockdownStageTestMixin:
"""Shared setup helpers for account lockdown stage tests."""
@classmethod
def setUpClass(cls):
cls.patch_enterprise_enabled = patch_enterprise_enabled.start()
cls.patch_event_dispatch = patch("authentik.events.tasks.event_trigger_dispatch.send")
cls.patch_event_dispatch.start()
super().setUpClass()
@classmethod
def tearDownClass(cls):
cls.patch_event_dispatch.stop()
patch_enterprise_enabled.stop()
super().tearDownClass()
def setUp(self):
super().setUp()
self.user = create_test_admin_user()
self.target_user = create_test_admin_user()
self.flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
self.stage = AccountLockdownStage.objects.create(
name="lockdown",
)
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=0)
self.request_factory = RequestFactory()
def make_stage_view(self, plan: FlowPlan):
def _stage_ok():
return HttpResponse(status=204)
def _stage_invalid(_error_message=None):
return HttpResponse(status=400)
return AccountLockdownStageView(
SimpleNamespace(
plan=plan,
current_stage=self.stage,
current_binding=self.binding,
flow=self.flow,
stage_ok=_stage_ok,
stage_invalid=_stage_invalid,
)
)
def make_request(self, *, user=None, query=None):
return self.request_factory.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
query_params=query or {},
user=user,
)
def get_lockdown_event(self):
"""Return the account-lockdown user-write event."""
return Event.objects.filter(
action=EventAction.USER_WRITE,
context__action_id=LOCKDOWN_EVENT_ACTION_ID,
).first()
class TestAccountLockdownStage(AccountLockdownStageTestMixin, FlowTestCase):
"""Account lockdown stage tests"""
def test_lockdown_no_target(self):
"""Test lockdown stage with no pending user fails"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
response = view.dispatch(self.make_request())
self.assertEqual(response.status_code, 400)
def test_lockdown_with_pending_user(self):
"""Test lockdown stage with a pending target user."""
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_LOCKDOWN_REASON] = "Security incident"
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(user=self.user)
self.assertTrue(can_lock_user(request.user, self.target_user))
response = view.dispatch(request)
self.target_user.refresh_from_db()
self.assertFalse(self.target_user.is_active)
self.assertFalse(self.target_user.has_usable_password())
self.assertEqual(response.status_code, 204)
# Check event was created
event = self.get_lockdown_event()
self.assertIsNotNone(event)
self.assertEqual(event.context["action_id"], LOCKDOWN_EVENT_ACTION_ID)
self.assertEqual(event.context["reason"], "Security incident")
self.assertEqual(event.context["affected_user"], self.target_user.username)
def test_lockdown_with_pending_user_reason(self):
"""Test lockdown stage with a pending target and explicit reason."""
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_LOCKDOWN_REASON] = "Compromised account"
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(user=self.user)
self.assertTrue(can_lock_user(request.user, self.target_user))
response = view.dispatch(request)
self.target_user.refresh_from_db()
self.assertFalse(self.target_user.is_active)
self.assertEqual(response.status_code, 204)
def test_lockdown_reason_from_prompt(self):
"""Test lockdown stage reads the reason from prompt data."""
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PROMPT] = {
PLAN_CONTEXT_LOCKDOWN_REASON: "User requested lockdown",
}
view = self.make_stage_view(plan)
request = self.make_request(user=self.user)
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
event = self.get_lockdown_event()
self.assertIsNotNone(event)
self.assertEqual(event.context["reason"], "User requested lockdown")
def test_lockdown_event_failure_does_not_fail_self_service(self):
"""Test lockdown still succeeds when event emission fails."""
self.stage.delete_sessions = False
self.stage.save()
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(user=self.target_user)
original_event_new = Event.new
def _event_new_side_effect(action, *args, **kwargs):
if (
action == EventAction.USER_WRITE
and kwargs.get("action_id") == LOCKDOWN_EVENT_ACTION_ID
):
raise RuntimeError("simulated event failure")
return original_event_new(action, *args, **kwargs)
with patch(
"authentik.enterprise.stages.account_lockdown.stage.Event.new",
side_effect=_event_new_side_effect,
):
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
self.target_user.refresh_from_db()
self.assertFalse(self.target_user.is_active)
def test_dispatch_records_success_when_event_emission_fails(self):
"""Test dispatch still completes if event emission fails."""
self.stage.delete_sessions = False
self.stage.save()
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(
user=self.target_user,
)
original_event_new = Event.new
def _event_new_side_effect(action, *args, **kwargs):
if (
action == EventAction.USER_WRITE
and kwargs.get("action_id") == LOCKDOWN_EVENT_ACTION_ID
):
raise RuntimeError("simulated event failure")
return original_event_new(action, *args, **kwargs)
with patch(
"authentik.enterprise.stages.account_lockdown.stage.Event.new",
side_effect=_event_new_side_effect,
):
response = view.dispatch(request)
self.target_user.refresh_from_db()
self.assertFalse(self.target_user.is_active)
self.assertEqual(response.status_code, 204)
def test_lockdown_self_service_redirects_to_completion_flow(self):
"""Test self-service lockdown redirects to completion flow when sessions are deleted."""
completion_flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
self.stage.self_service_completion_flow = completion_flow
self.stage.save()
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
request = self.make_request(user=self.target_user)
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
response = view._self_service_completion_response(request)
self.assertEqual(response.status_code, 302)
self.assertEqual(
response.url,
reverse("authentik_core:if-flow", kwargs={"flow_slug": completion_flow.slug}),
)
def test_lockdown_self_service_requires_completion_flow(self):
"""Test self-service lockdown fails before deleting sessions without a completion flow."""
self.stage.self_service_completion_flow = None
self.stage.save()
self.target_user.is_active = True
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(user=self.target_user)
response = view.dispatch(request)
self.assertEqual(response.status_code, 400)
self.target_user.refresh_from_db()
self.assertTrue(self.target_user.is_active)
def test_lockdown_denies_other_user_without_permission(self):
"""Test lockdown stage rejects non-self requests without change_user permission."""
actor = create_test_user()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
view = self.make_stage_view(plan)
request = self.make_request(user=actor)
self.assertFalse(can_lock_user(request.user, self.target_user))
response = view.dispatch(request)
self.assertEqual(response.status_code, 400)
def test_lockdown_revokes_tokens(self):
"""Test lockdown stage revokes tokens"""
Token.objects.create(
user=self.target_user,
identifier="test-token",
intent=TokenIntents.INTENT_API,
key=generate_id(),
expiring=False,
)
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 1)
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 0)
def test_lockdown_revokes_provider_tokens(self):
"""Test lockdown stage revokes provider tokens and sessions."""
oauth_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver/callback")
],
signing_key=create_test_cert(),
)
saml_provider = SAMLProvider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
acs_url="https://sp.example.com/acs",
issuer_override="https://idp.example.com",
)
session = Session.objects.create(
session_key=generate_id(),
expires=timezone.now() + timezone.timedelta(hours=1),
last_ip="127.0.0.1",
)
auth_session = AuthenticatedSession.objects.create(
session=session,
user=self.target_user,
)
grant_kwargs = {
"provider": oauth_provider,
"user": self.target_user,
"auth_time": timezone.now(),
"_scope": "openid profile",
"expiring": False,
}
token_kwargs = grant_kwargs | {"_id_token": json.dumps(asdict(IDToken("foo", "bar")))}
AuthorizationCode.objects.create(
code=generate_id(),
session=auth_session,
**grant_kwargs,
)
AccessToken.objects.create(
token=generate_id(),
session=auth_session,
**token_kwargs,
)
RefreshToken.objects.create(
token=generate_id(),
session=auth_session,
**token_kwargs,
)
DeviceToken.objects.create(
provider=oauth_provider,
user=self.target_user,
session=auth_session,
_scope="openid profile",
expiring=False,
)
SAMLSession.objects.create(
provider=saml_provider,
user=self.target_user,
session=auth_session,
session_index=generate_id(),
name_id=self.target_user.email,
expires=timezone.now() + timezone.timedelta(hours=1),
expiring=True,
)
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
self.assertEqual(AuthorizationCode.objects.filter(user=self.target_user).count(), 0)
self.assertEqual(AccessToken.objects.filter(user=self.target_user).count(), 0)
self.assertEqual(RefreshToken.objects.filter(user=self.target_user).count(), 0)
self.assertEqual(DeviceToken.objects.filter(user=self.target_user).count(), 0)
self.assertEqual(SAMLSession.objects.filter(user=self.target_user).count(), 0)
def test_lockdown_selective_actions(self):
"""Test lockdown stage with selective actions"""
self.stage.deactivate_user = True
self.stage.set_unusable_password = False
self.stage.delete_sessions = False
self.stage.revoke_tokens = False
self.stage.save()
self.target_user.is_active = True
self.target_user.set_password("testpassword")
self.target_user.save()
Token.objects.create(
user=self.target_user,
identifier="test-token",
intent=TokenIntents.INTENT_API,
key=generate_id(),
expiring=False,
)
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
self.target_user.refresh_from_db()
# User should be deactivated
self.assertFalse(self.target_user.is_active)
# Password should still be usable
self.assertTrue(self.target_user.has_usable_password())
# Token should still exist
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 1)
def test_lockdown_no_actions(self):
"""Test lockdown stage with all actions disabled"""
self.stage.deactivate_user = False
self.stage.set_unusable_password = False
self.stage.delete_sessions = False
self.stage.revoke_tokens = False
self.stage.save()
self.target_user.is_active = True
self.target_user.set_password("testpassword")
self.target_user.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
self.target_user.refresh_from_db()
# User should still be active
self.assertTrue(self.target_user.is_active)
# Password should still be usable
self.assertTrue(self.target_user.has_usable_password())
# Event should still be created
event = self.get_lockdown_event()
self.assertIsNotNone(event)
def test_lockdown_deactivation_inhibits_signal_dispatch_until_after_commit(self):
"""Test lockdown queues explicit outgoing syncs after the deactivation transaction."""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
with (
patch(
"authentik.enterprise.stages.account_lockdown.stage.sync_outgoing_inhibit_dispatch"
) as inhibit,
patch.object(view, "_sync_deactivated_user_to_outgoing_providers") as sync_outgoing,
):
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
inhibit.assert_called_once()
sync_outgoing.assert_called_once()
synced_user = sync_outgoing.call_args.args[0]
self.assertEqual(synced_user.pk, self.target_user.pk)
self.assertFalse(synced_user.is_active)
def test_lockdown_waits_for_direct_outgoing_provider_syncs(self):
"""Test direct outgoing sync tasks are enqueued and waited on."""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
provider = SimpleNamespace(name="outgoing", pk=1, sync_page_timeout="seconds=5")
task_sync_direct = MagicMock()
task_sync_direct.message_with_options.return_value = "direct-message"
provider_model = SimpleNamespace(
objects=SimpleNamespace(filter=MagicMock(return_value=[provider]))
)
task_group = MagicMock()
with (
patch(
"authentik.enterprise.stages.account_lockdown.stage.get_outgoing_sync_tasks",
return_value=((provider_model, task_sync_direct),),
),
patch(
"authentik.enterprise.stages.account_lockdown.stage.group",
return_value=task_group,
) as task_group_cls,
):
view._sync_deactivated_user_to_outgoing_providers(self.target_user)
task_sync_direct.message_with_options.assert_called_once_with(
args=(class_to_path(type(self.target_user)), self.target_user.pk, provider.pk),
rel_obj=provider,
time_limit=5000,
uid=f"{provider.name}:user:{self.target_user.pk}:direct",
)
task_group_cls.assert_called_once_with(["direct-message"])
task_group.run.return_value.wait.assert_called_once_with(timeout=5000)
def test_lockdown_outgoing_provider_sync_timeout_leaves_tasks_running(self):
"""Test timeout while waiting for direct outgoing syncs does not fail lockdown."""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
provider = SimpleNamespace(name="outgoing", pk=1, sync_page_timeout="seconds=5")
task_sync_direct = MagicMock()
task_sync_direct.message_with_options.return_value = "direct-message"
provider_model = SimpleNamespace(
objects=SimpleNamespace(filter=MagicMock(return_value=[provider]))
)
task_group = MagicMock()
task_group.run.return_value.wait.side_effect = ResultTimeout("timed out")
with (
patch(
"authentik.enterprise.stages.account_lockdown.stage.get_outgoing_sync_tasks",
return_value=((provider_model, task_sync_direct),),
),
patch(
"authentik.enterprise.stages.account_lockdown.stage.group",
return_value=task_group,
),
):
view._sync_deactivated_user_to_outgoing_providers(self.target_user)
task_group.run.assert_called_once_with()
task_group.run.return_value.wait.assert_called_once_with(timeout=5000)
def test_lockdown_outgoing_provider_sync_failure_does_not_fail_lockdown(self):
"""Test completed local lockdown still emits an event if outgoing sync fails."""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
with patch.object(
view,
"_sync_deactivated_user_to_outgoing_providers",
side_effect=ValueError("sync failed"),
):
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
self.target_user.refresh_from_db()
self.assertFalse(self.target_user.is_active)
event = self.get_lockdown_event()
self.assertIsNotNone(event)
class TestAccountLockdownStageConcurrency(AccountLockdownStageTestMixin, TransactionTestCase):
"""Account lockdown concurrency tests."""
def test_lockdown_retries_when_another_transaction_recreates_a_token(self):
"""Lockdown should remove a token recreated before the retry check runs."""
Token.objects.create(
user=self.target_user,
identifier=f"initial-token-{generate_id()}",
intent=TokenIntents.INTENT_API,
key=generate_id(),
expiring=False,
)
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
view = self.make_stage_view(plan)
original_has_artifacts = view._has_lockdown_artifacts
target_user = self.target_user
thread_ready = ThreadEvent()
start_create = ThreadEvent()
thread_done = ThreadEvent()
thread_errors = []
class TokenCreatorThread(Thread):
__test__ = False
def run(self):
try:
thread_ready.set()
if not start_create.wait(timeout=5):
thread_errors.append("timed out waiting to recreate token")
return
Token.objects.create(
user=target_user,
identifier=f"concurrent-token-{generate_id()}",
intent=TokenIntents.INTENT_API,
key=generate_id(),
expiring=False,
)
except Exception as exc: # noqa: BLE001
thread_errors.append(exc)
finally:
thread_done.set()
connection.close()
def has_artifacts_after_concurrent_create(stage, user):
if not start_create.is_set():
start_create.set()
self.assertTrue(
thread_done.wait(timeout=30),
(
"Concurrent token creation did not complete "
f"before retry check: {thread_errors}"
),
)
return original_has_artifacts(stage, user)
creator = TokenCreatorThread()
with patch.object(
view, "_has_lockdown_artifacts", side_effect=has_artifacts_after_concurrent_create
):
creator.start()
self.assertTrue(
thread_ready.wait(timeout=5),
"Concurrent token creation thread did not start",
)
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
creator.join()
self.assertEqual(thread_errors, [])
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 0)

View File

@@ -1,5 +0,0 @@
"""API URLs"""
from authentik.enterprise.stages.account_lockdown.api import AccountLockdownStageViewSet
api_urlpatterns = [("stages/account_lockdown", AccountLockdownStageViewSet)]

View File

@@ -8,6 +8,7 @@ from inspect import currentframe
from typing import Any
from uuid import uuid4
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.apps import apps
from django.db import models
@@ -409,7 +410,7 @@ class NotificationTransport(TasksModel, SerializerModel):
)
notification.save()
layer = get_channel_layer()
layer.group_send_blocking(
async_to_sync(layer.group_send)(
build_user_group(notification.user),
{
"type": "event.notification",

View File

@@ -11,7 +11,7 @@ from django.http import HttpRequest
from rest_framework.request import Request
from authentik.core.models import AuthenticatedSession, User
from authentik.core.signals import login_failed, password_changed, password_hash_changed
from authentik.core.signals import login_failed, password_changed
from authentik.events.models import Event, EventAction
from authentik.flows.models import Stage
from authentik.flows.planner import (
@@ -112,15 +112,8 @@ def on_invitation_used(sender, request: HttpRequest, invitation: Invitation, **_
)
@receiver(password_hash_changed)
@receiver(password_changed)
def on_password_changed(
sender,
user: User,
password: str | None = None,
request: HttpRequest | None = None,
**_,
):
def on_password_changed(sender, user: User, password: str, request: HttpRequest | None, **_):
"""Log password change"""
Event.new(EventAction.PASSWORD_SET).from_http(request, user=user)

View File

@@ -2,7 +2,6 @@
from urllib.parse import urlencode
from django.contrib.auth.hashers import make_password
from django.contrib.contenttypes.models import ContentType
from django.test import RequestFactory, TestCase
from django.views.debug import SafeExceptionReporterFilter
@@ -11,7 +10,7 @@ from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
@@ -214,14 +213,3 @@ class TestEvents(TestCase):
event = Event.new("unittest", foo="foo bar \u0000 baz")
event.save()
self.assertEqual(event.context["foo"], "foo bar baz")
def test_password_set_signal_on_set_password_from_hash(self):
"""Changing password from hash should still emit an audit event."""
user = create_test_user()
old_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
user.set_password_from_hash(make_password(generate_id()))
user.save()
new_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
self.assertEqual(new_count, old_count + 1)

View File

@@ -1,5 +1,5 @@
<html>
<script>
<script nonce="{{ csp_nonce }}">
window.parent.postMessage({
message: "submit",
source: "goauthentik.io",

View File

@@ -17,13 +17,13 @@
<meta name="sentry-trace" content="{{ sentry_trace }}" />
<link rel="prefetch" href="{{ flow_background_url }}" />
{% include "base/header_js.html" %}
<style data-id="flow-sfe">
<style data-id="flow-sfe" nonce="{{ csp_nonce }}">
html,
body {
height: 100%;
}
body {
background-image: url("{{ flow_background_url|iriencode|safe }}");
background-image: url("{{ flow_background_url }}");
background-repeat: no-repeat;
background-size: cover;
}
@@ -43,13 +43,13 @@
max-width: 100%;
}
</style>
<script src="{% static 'dist/sfe/index.js' %}"></script>
</head>
<body class="d-flex align-items-center py-4 bg-body-tertiary">
<div class="card m-auto">
<main class="form-signin w-100 m-auto" id="flow-sfe-container">
</main>
<span class="mt-3 mb-0 text-muted text-center">{% trans 'Powered by authentik' %}</span>
</div>
<script src="{% static 'dist/sfe/index.js' %}"></script>
<div class="card m-auto">
<main class="form-signin w-100 m-auto" id="flow-sfe-container">
</main>
<span class="mt-3 mb-0 text-muted text-center">{% trans 'Powered by authentik' %}</span>
</div>
</body>
</html>

View File

@@ -12,7 +12,7 @@
{% comment %}
@see {@link web/types/webcomponents.d.ts} for type definitions.
{% endcomment %}
<script data-id="shady-dom">
<script data-id="shady-dom" nonce="{{ csp_nonce }}">
"use strict";
window.ShadyDOM = window.ShadyDOM || {}
@@ -20,7 +20,7 @@
</script>
{% endif %}
{% include "base/header_js.html" %}
<script data-id="flow-config">
<script data-id="flow-config" nonce="{{ csp_nonce }}">
"use strict";
window.authentik.flow = {
@@ -37,9 +37,9 @@
{% block head %}
<script src="{% versioned_script 'dist/flow/FlowInterface-%v.js' %}" type="module"></script>
<style data-id="flow-css">
<style data-id="flow-css" nonce="{{ csp_nonce }}">
:root {
--ak-global--background-image: url("{{ flow_background_url|iriencode|safe }}");
--ak-global--background-image: url("{{ flow_background_url }}");
}
</style>
{% endblock %}
@@ -55,10 +55,10 @@
loading
>
{% include "base/placeholder.html" %}
<ak-brand-links name="flow-links" slot="footer"></ak-brand-links>
</ak-flow-executor>
<ak-flow-inspector
slot="panel"
id="flow-inspector"

View File

@@ -1,14 +1,12 @@
"""stage view tests"""
from collections.abc import Callable
from unittest.mock import patch
from django.test import RequestFactory, TestCase
from django.urls import reverse
from authentik.core.tests.utils import RequestFactory as AuthentikRequestFactory
from authentik.core.tests.utils import create_test_flow
from authentik.flows.models import Flow, FlowStageBinding
from authentik.flows.models import FlowStageBinding
from authentik.flows.stage import StageView
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import all_subclasses
@@ -44,46 +42,6 @@ class TestViews(TestCase):
"/static/dist/assets/images/flow_background.jpg",
)
def test_flow_interface_css_background_preserves_presigned_url_query(self):
"""Test flow CSS keeps signed URL query separators intact."""
flow = create_test_flow()
background_url = (
"https://s3.ca-central-1.amazonaws.com/example/media/public/background.png"
"?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=credential"
"&X-Amz-Signature=signature"
)
with patch.object(Flow, "background_url", return_value=background_url):
response = self.client.get(
reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)
self.assertContains(
response,
f'--ak-global--background-image: url("{background_url}");',
html=False,
)
def test_flow_sfe_css_background_preserves_presigned_url_query(self):
"""Test SFE flow CSS keeps signed URL query separators intact."""
flow = create_test_flow()
background_url = (
"https://s3.ca-central-1.amazonaws.com/example/media/public/background.png"
"?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=credential"
"&X-Amz-Signature=signature"
)
with patch.object(Flow, "background_url", return_value=background_url):
response = self.client.get(
reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) + "?sfe"
)
self.assertContains(
response,
f'background-image: url("{background_url}");',
html=False,
)
def view_tester_factory(view_class: type[StageView]) -> Callable:
"""Test a form"""

View File

@@ -165,6 +165,8 @@ web:
timeout_http_read: 30s
timeout_http_write: 60s
timeout_http_idle: 120s
csp:
report_only: false
worker:
processes: 1

View File

@@ -53,16 +53,6 @@ class TestEndSessionView(OAuthTestCase):
self.brand.flow_invalidation = self.invalidation_flow
self.brand.save()
def _id_token_hint(self, host: str) -> str:
"""Issue a valid id_token_hint for the test provider under the given host."""
return self.provider.encode(
{
"iss": f"http://{host}/application/o/{self.app.slug}/",
"aud": self.provider.client_id,
"sub": str(self.user.pk),
}
)
def test_post_logout_redirect_uri_strict_match(self):
"""Test strict URI matching redirects to flow"""
self.client.force_login(self.user)
@@ -71,10 +61,7 @@ class TestEndSessionView(OAuthTestCase):
"authentik_providers_oauth2:end-session",
kwargs={"application_slug": self.app.slug},
),
{
"post_logout_redirect_uri": "http://testserver/logout",
"id_token_hint": self._id_token_hint(self.brand.domain),
},
{"post_logout_redirect_uri": "http://testserver/logout"},
HTTP_HOST=self.brand.domain,
)
# Should redirect to the invalidation flow
@@ -82,12 +69,7 @@ class TestEndSessionView(OAuthTestCase):
self.assertIn(self.invalidation_flow.slug, response.url)
def test_post_logout_redirect_uri_strict_no_match(self):
"""Test strict URI not matching returns an error and does not start logout flow.
Required by OIDC RP-Initiated Logout 1.0: on an unregistered
post_logout_redirect_uri, the OP MUST NOT redirect and MUST NOT proceed with
logout that targets the RP.
"""
"""Test strict URI not matching still proceeds with flow (no redirect URI in context)"""
self.client.force_login(self.user)
invalid_uri = "http://testserver/other"
response = self.client.get(
@@ -95,14 +77,12 @@ class TestEndSessionView(OAuthTestCase):
"authentik_providers_oauth2:end-session",
kwargs={"application_slug": self.app.slug},
),
{
"post_logout_redirect_uri": invalid_uri,
"id_token_hint": self._id_token_hint(self.brand.domain),
},
{"post_logout_redirect_uri": invalid_uri},
HTTP_HOST=self.brand.domain,
)
self.assertEqual(response.status_code, 400)
self.assertNotIn(invalid_uri, response.content.decode())
# Should still redirect to flow, but invalid URI should not be in response
self.assertEqual(response.status_code, 302)
self.assertNotIn(invalid_uri, response.url)
def test_post_logout_redirect_uri_regex_match(self):
"""Test regex URI matching redirects to flow"""
@@ -112,10 +92,7 @@ class TestEndSessionView(OAuthTestCase):
"authentik_providers_oauth2:end-session",
kwargs={"application_slug": self.app.slug},
),
{
"post_logout_redirect_uri": "https://app.example.com/logout",
"id_token_hint": self._id_token_hint(self.brand.domain),
},
{"post_logout_redirect_uri": "https://app.example.com/logout"},
HTTP_HOST=self.brand.domain,
)
# Should redirect to the invalidation flow
@@ -123,7 +100,7 @@ class TestEndSessionView(OAuthTestCase):
self.assertIn(self.invalidation_flow.slug, response.url)
def test_post_logout_redirect_uri_regex_no_match(self):
"""Test regex URI not matching returns an error and does not start logout flow."""
"""Test regex URI not matching"""
self.client.force_login(self.user)
invalid_uri = "https://malicious.com/logout"
response = self.client.get(
@@ -131,14 +108,12 @@ class TestEndSessionView(OAuthTestCase):
"authentik_providers_oauth2:end-session",
kwargs={"application_slug": self.app.slug},
),
{
"post_logout_redirect_uri": invalid_uri,
"id_token_hint": self._id_token_hint(self.brand.domain),
},
{"post_logout_redirect_uri": invalid_uri},
HTTP_HOST=self.brand.domain,
)
self.assertEqual(response.status_code, 400)
self.assertNotIn(invalid_uri, response.content.decode())
# Should still proceed to flow, but invalid URI should not be in response
self.assertEqual(response.status_code, 302)
self.assertNotIn(invalid_uri, response.url)
def test_state_parameter_appended_to_uri(self):
"""Test state parameter is appended to validated redirect URI"""
@@ -148,7 +123,6 @@ class TestEndSessionView(OAuthTestCase):
{
"post_logout_redirect_uri": "http://testserver/logout",
"state": "test-state-123",
"id_token_hint": self._id_token_hint("testserver"),
},
)
request.user = self.user
@@ -158,7 +132,6 @@ class TestEndSessionView(OAuthTestCase):
view.request = request
view.kwargs = {"application_slug": self.app.slug}
view.resolve_provider_application()
view.validate()
self.assertIn("state=test-state-123", view.post_logout_redirect_uri)
@@ -173,7 +146,6 @@ class TestEndSessionView(OAuthTestCase):
{
"post_logout_redirect_uri": "http://testserver/logout",
"state": "xyz789",
"id_token_hint": self._id_token_hint(self.brand.domain),
},
HTTP_HOST=self.brand.domain,
)

View File

@@ -5,8 +5,6 @@ from urllib.parse import quote, urlparse
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404
from jwt import PyJWTError
from jwt import decode as jwt_decode
from authentik.common.oauth.constants import (
FORBIDDEN_URI_SCHEMES,
@@ -23,14 +21,11 @@ from authentik.flows.planner import (
from authentik.flows.stage import SessionEndStage
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.views import bad_request_message
from authentik.policies.views import PolicyAccessView
from authentik.policies.views import PolicyAccessView, RequestValidationError
from authentik.providers.iframe_logout import IframeLogoutStageView
from authentik.providers.oauth2.errors import TokenError
from authentik.providers.oauth2.models import (
AccessToken,
JWTAlgorithms,
OAuth2LogoutMethod,
OAuth2Provider,
RedirectURIMatchingMode,
)
from authentik.providers.oauth2.tasks import send_backchannel_logout_request
@@ -52,45 +47,21 @@ class EndSessionView(PolicyAccessView):
if not self.flow:
raise Http404
def validate(self):
# Parse end session parameters
query_dict = self.request.POST if self.request.method == "POST" else self.request.GET
state = query_dict.get("state")
request_redirect_uri = query_dict.get("post_logout_redirect_uri")
id_token_hint = query_dict.get("id_token_hint")
self.post_logout_redirect_uri = None
# OIDC Certification: Verify id_token_hint. If invalid or missing, throw an error
if id_token_hint:
# Load a fresh provider instance that's not part of the flow
# since it'll have the cryptography Certificate that can't be pickled
provider = OAuth2Provider.objects.get(pk=self.provider.pk)
key, alg = provider.jwt_key
if alg != JWTAlgorithms.HS256:
key = provider.signing_key.public_key
try:
jwt_decode(
id_token_hint,
key,
algorithms=[alg],
audience=provider.client_id,
issuer=provider.get_issuer(self.request),
# ID Tokens are short-lived; a logout request arriving
# after expiry is still legitimate and must succeed.
options={"verify_exp": False},
)
except PyJWTError:
raise TokenError("invalid_request").with_cause(
"id_token_hint_decode_failed"
) from None
# Validate post_logout_redirect_uri against registered URIs
if request_redirect_uri:
# OIDC Certification: id_token_hint required with post_logout_redirect_uri
if not id_token_hint:
raise TokenError("invalid_request").with_cause("id_token_hint_missing")
if urlparse(request_redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise TokenError("invalid_request").with_cause("post_logout_redirect_uri")
raise RequestValidationError(
bad_request_message(
self.request,
"Forbidden URI scheme in post_logout_redirect_uri",
)
)
for allowed in self.provider.post_logout_redirect_uris:
if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
if request_redirect_uri == allowed.url:
@@ -100,10 +71,6 @@ class EndSessionView(PolicyAccessView):
if fullmatch(allowed.url, request_redirect_uri):
self.post_logout_redirect_uri = request_redirect_uri
break
# OIDC Certification: OP MUST NOT perform post-logout redirection
# if the supplied URI does not exactly match a registered one
if self.post_logout_redirect_uri is None:
raise TokenError("invalid_request").with_cause("invalid_post_logout_redirect_uri")
# Append state to the redirect URI if both are present
if self.post_logout_redirect_uri and state:
@@ -124,43 +91,50 @@ class EndSessionView(PolicyAccessView):
"<html><body>Logout successful</body></html>", content_type="text/html", status=200
)
# Otherwise, continue with normal policy checks
return super().dispatch(request, *args, **kwargs)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Dispatch the flow planner for the invalidation flow"""
try:
self.validate()
except TokenError as exc:
return bad_request_message(
self.request,
exc.description,
)
planner = FlowPlanner(self.flow)
planner.allow_empty_flows = True
# Build flow context with logout parameters
context = {
PLAN_CONTEXT_APPLICATION: self.application,
}
# Get session info for logout notifications and token invalidation
auth_session = AuthenticatedSession.from_request(request, request.user)
# Add validated redirect URI (with state appended) to context if available
if self.post_logout_redirect_uri:
context[PLAN_CONTEXT_POST_LOGOUT_REDIRECT_URI] = self.post_logout_redirect_uri
# Invalidate tokens for this provider/session (RP-initiated logout:
# user stays logged into authentik, only this provider's tokens are revoked)
if request.user.is_authenticated and auth_session:
AccessToken.objects.filter(
user=request.user,
provider=self.provider,
session=auth_session,
).delete()
session_key = (
auth_session.session.session_key if auth_session and auth_session.session else None
)
# Handle frontchannel logout
frontchannel_logout_url = None
if self.provider.logout_method == OAuth2LogoutMethod.FRONTCHANNEL:
frontchannel_logout_url = build_frontchannel_logout_url(
self.provider, request, session_key
)
# Handle backchannel logout
if (
self.provider.logout_method == OAuth2LogoutMethod.BACKCHANNEL
and self.provider.logout_uri
):
# Find access token to get iss and sub for the logout token
access_token = AccessToken.objects.filter(
user=request.user,
provider=self.provider,
@@ -189,16 +163,9 @@ class EndSessionView(PolicyAccessView):
}
]
access_tokens = AccessToken.objects.filter(
user=request.user,
provider=self.provider,
)
if auth_session:
access_tokens = access_tokens.filter(session=auth_session)
access_tokens.delete()
plan = planner.plan(request, context)
# Inject iframe logout stage if frontchannel logout is configured
if frontchannel_logout_url:
plan.insert_stage(in_memory_stage(IframeLogoutStageView))

View File

@@ -1,5 +1,6 @@
"""RAC Signals"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_delete
@@ -17,7 +18,7 @@ from authentik.providers.rac.models import ConnectionToken, Endpoint
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer()
layer.group_send_blocking(
async_to_sync(layer.group_send)(
build_rac_client_group_session(instance.session.session_key),
{"type": "event.disconnect", "reason": "session_logout"},
)
@@ -27,7 +28,7 @@ def user_session_deleted(sender, instance: AuthenticatedSession, **_):
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
"""Disconnect session when connection token is deleted"""
layer = get_channel_layer()
layer.group_send_blocking(
async_to_sync(layer.group_send)(
build_rac_client_group_token(instance.token),
{"type": "event.disconnect", "reason": "token_delete"},
)

View File

@@ -24,11 +24,7 @@ from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.validation import validate
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
SAML_BINDING_POST,
SAML_BINDING_REDIRECT,
)
from authentik.common.saml.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
@@ -59,7 +55,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"""SAMLProvider Serializer"""
url_download_metadata = SerializerMethodField()
url_issuer = SerializerMethodField()
url_sso_post = SerializerMethodField()
url_sso_redirect = SerializerMethodField()
@@ -90,23 +85,6 @@ class SAMLProviderSerializer(ProviderSerializer):
+ "?download"
)
def get_url_issuer(self, instance: SAMLProvider) -> str:
"""Get Issuer/EntityID URL"""
if instance.issuer_override:
return instance.issuer_override
if "request" not in self._context:
return DEFAULT_ISSUER
request: HttpRequest = self._context["request"]._request
try:
return request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": instance.application.slug},
)
)
except Provider.application.RelatedObjectDoesNotExist:
return DEFAULT_ISSUER
def get_url_sso_post(self, instance: SAMLProvider) -> str:
"""Get SSO Post URL"""
if "request" not in self._context:
@@ -220,7 +198,7 @@ class SAMLProviderSerializer(ProviderSerializer):
"acs_url",
"sls_url",
"audience",
"issuer_override",
"issuer",
"assertion_valid_not_before",
"assertion_valid_not_on_or_after",
"session_valid_not_on_or_after",
@@ -242,7 +220,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"default_relay_state",
"default_name_id_policy",
"url_download_metadata",
"url_issuer",
"url_sso_post",
"url_sso_redirect",
"url_sso_init",

View File

@@ -1,34 +0,0 @@
# Generated by Django 5.2.11 on 2026-02-24 06:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0021_samlprovider_sign_logout_response"),
]
operations = [
migrations.RenameField(
model_name="samlprovider",
old_name="issuer",
new_name="issuer_override",
),
migrations.AlterField(
model_name="samlprovider",
name="issuer_override",
field=models.TextField(
blank=True,
default="",
help_text="Also known as EntityID. Providing a value overrides the default issuer generated by authentik.",
),
),
migrations.AddField(
model_name="samlsession",
name="issuer",
field=models.TextField(
default=None, help_text="SAML Issuer used for this session", null=True
),
),
]

View File

@@ -77,14 +77,7 @@ class SAMLProvider(Provider):
"no audience restriction will be added."
),
)
issuer_override = models.TextField(
blank=True,
default="",
help_text=_(
"Also known as EntityID. Providing a value overrides the default issuer "
"generated by authentik."
),
)
issuer = models.TextField(help_text=_("Also known as EntityID"), default="authentik")
sls_url = models.TextField(
blank=True,
validators=[DomainlessURLValidator(schemes=("http", "https"))],
@@ -325,9 +318,6 @@ class SAMLSession(InternallyManagedMixin, SerializerModel, ExpiringModel):
session_index = models.TextField(help_text=_("SAML SessionIndex for this session"))
name_id = models.TextField(help_text=_("SAML NameID value for this session"))
name_id_format = models.TextField(default="", blank=True, help_text=_("SAML NameID format"))
issuer = models.TextField(
default=None, null=True, help_text=_("SAML Issuer used for this session")
)
created = models.DateTimeField(auto_now_add=True)
@property

View File

@@ -6,7 +6,6 @@ from types import GeneratorType
import xmlsec
from django.http import HttpRequest
from django.urls import reverse
from django.utils.timezone import now
from lxml import etree # nosec
from lxml.etree import Element, SubElement, _Element # nosec
@@ -64,7 +63,6 @@ class AssertionProcessor:
session_index: str
name_id: str
name_id_format: str
issuer: str
session_not_on_or_after_datetime: datetime
def __init__(self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest):
@@ -139,24 +137,10 @@ class AssertionProcessor:
continue
return attribute_statement
def _get_issuer_value(self) -> str:
"""Get issuer value, with fallback to generated URL if empty"""
# If user has set an override issuer, use it
if self.provider.issuer_override:
return self.provider.issuer_override
return self.http_request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": self.provider.application.slug},
)
)
def get_issuer(self) -> Element:
"""Get Issuer Element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer", nsmap=NS_MAP)
self.issuer = self._get_issuer_value()
issuer.text = self.issuer
issuer.text = self.provider.issuer
return issuer
def get_assertion_auth_n_statement(self) -> Element:

View File

@@ -8,7 +8,6 @@ from lxml import etree # nosec
from lxml.etree import Element, _Element
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP,
NS_SAML_ASSERTION,
@@ -34,12 +33,11 @@ class LogoutRequestProcessor:
name_id_format: str
session_index: str | None
relay_state: str | None
issuer: str | None
_issue_instant: str
_request_id: str
def __init__( # noqa: PLR0913
def __init__(
self,
provider: SAMLProvider,
user: User | None,
@@ -48,7 +46,6 @@ class LogoutRequestProcessor:
name_id_format: str = SAML_NAME_ID_FORMAT_EMAIL,
session_index: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
self.provider = provider
self.user = user
@@ -57,23 +54,14 @@ class LogoutRequestProcessor:
self.name_id_format = name_id_format
self.session_index = session_index
self.relay_state = relay_state
self.issuer = issuer
self._issue_instant = get_time_string()
self._request_id = get_random_id()
def _get_issuer_value(self) -> str:
"""Get issuer value from session, with fallback to provider"""
if self.issuer:
return self.issuer
if self.provider.issuer_override:
return self.provider.issuer_override
return DEFAULT_ISSUER
def get_issuer(self) -> Element:
"""Get Issuer element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
issuer.text = self._get_issuer_value()
issuer.text = self.provider.issuer
return issuer
def get_name_id(self) -> Element:

View File

@@ -8,7 +8,6 @@ from lxml import etree
from lxml.etree import Element, SubElement
from authentik.common.saml.constants import (
DEFAULT_ISSUER,
DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP,
NS_SAML_ASSERTION,
@@ -29,38 +28,27 @@ class LogoutResponseProcessor:
logout_request: LogoutRequest
destination: str | None
relay_state: str | None
issuer: str | None
_issue_instant: str
_response_id: str
def __init__( # noqa: PLR0913
def __init__(
self,
provider: SAMLProvider,
logout_request: LogoutRequest,
destination: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
self.provider = provider
self.logout_request = logout_request
self.destination = destination
self.relay_state = relay_state or (logout_request.relay_state if logout_request else None)
self.issuer = issuer
self._issue_instant = get_time_string()
self._response_id = get_random_id()
def _get_issuer_value(self) -> str:
"""Get issuer value from session, with fallback to provider"""
if self.issuer:
return self.issuer
if self.provider.issuer_override:
return self.provider.issuer_override
return DEFAULT_ISSUER
def get_issuer(self) -> Element:
"""Get Issuer element"""
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
issuer.text = self._get_issuer_value()
issuer.text = self.provider.issuer
return issuer
def build(self, status: str = "Success") -> Element:

View File

@@ -40,19 +40,6 @@ class MetadataProcessor:
self.force_binding = None
self.xml_id = "_" + sha256(f"{provider.name}-{provider.pk}".encode("ascii")).hexdigest()
def _get_issuer_value(self) -> str:
"""Get issuer value, with fallback to generated URL if empty"""
# If user has set an override issuer, use it
if self.provider.issuer_override:
return self.provider.issuer_override
return self.http_request.build_absolute_uri(
reverse(
"authentik_providers_saml:base",
kwargs={"application_slug": self.provider.application.slug},
)
)
# Using type unions doesn't work with cython types (which is what lxml is)
def get_signing_key_descriptor(self) -> Element | None:
"""Get Signing KeyDescriptor, if enabled for the provider"""
@@ -202,7 +189,7 @@ class MetadataProcessor:
"""Build full EntityDescriptor"""
entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP)
entity_descriptor.attrib["ID"] = self.xml_id
entity_descriptor.attrib["entityID"] = self._get_issuer_value()
entity_descriptor.attrib["entityID"] = self.provider.issuer
if self.provider.signing_kp:
self._prepare_signature(entity_descriptor)

View File

@@ -51,6 +51,7 @@ class ServiceProviderMetadata:
provider = SAMLProvider.objects.create(
name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
)
provider.issuer = self.entity_id
provider.sp_binding = self.acs_binding
provider.acs_url = self.acs_location
provider.default_name_id_policy = self.name_id_policy

View File

@@ -75,7 +75,6 @@ def handle_saml_iframe_pre_user_logout(
name_id_format=session.name_id_format,
session_index=session.session_index,
relay_state=relay_state,
issuer=session.issuer,
)
if session.provider.sls_binding == SAMLBindings.POST:
@@ -164,7 +163,6 @@ def handle_flow_pre_user_logout(
name_id_format=session.name_id_format,
session_index=session.session_index,
relay_state=relay_state,
issuer=session.issuer,
)
if session.provider.sls_binding == SAMLBindings.POST:
@@ -226,7 +224,6 @@ def user_session_deleted_saml_logout(sender, instance: AuthenticatedSession, **_
name_id=saml_session.name_id,
name_id_format=saml_session.name_id_format,
session_index=saml_session.session_index,
issuer=saml_session.issuer,
)
@@ -260,5 +257,4 @@ def user_deactivated_saml_logout(sender, instance: User, **kwargs):
name_id=saml_session.name_id,
name_id_format=saml_session.name_id_format,
session_index=saml_session.session_index,
issuer=saml_session.issuer,
)

View File

@@ -22,7 +22,6 @@ def send_saml_logout_request(
name_id: str,
name_id_format: str,
session_index: str,
issuer: str,
):
"""Send SAML LogoutRequest to a Service Provider using session data"""
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
@@ -48,7 +47,6 @@ def send_saml_logout_request(
name_id=name_id,
name_id_format=name_id_format,
session_index=session_index,
issuer=issuer,
)
return send_post_logout_request(provider, processor)
@@ -91,7 +89,6 @@ def send_saml_logout_response(
sls_url: str,
logout_request_id: str | None = None,
relay_state: str | None = None,
issuer: str | None = None,
):
"""Send SAML LogoutResponse to a Service Provider using backchannel (server-to-server)"""
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
@@ -122,7 +119,6 @@ def send_saml_logout_response(
logout_request=logout_request,
destination=sls_url,
relay_state=relay_state,
issuer=issuer,
)
encoded_response = processor.encode_post()

View File

@@ -15,7 +15,6 @@ from authentik.common.saml.constants import (
SAML_NAME_ID_FORMAT_EMAIL,
SAML_NAME_ID_FORMAT_UNSPECIFIED,
)
from authentik.core.models import Application
from authentik.core.tests.utils import (
RequestFactory,
create_test_admin_user,
@@ -98,11 +97,6 @@ class TestAuthNRequest(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
Application.objects.create(
name="test-app",
slug="test-app",
provider=self.provider,
)
self.source = SAMLSource.objects.create(
slug="provider",
issuer="authentik",
@@ -532,7 +526,7 @@ class TestAuthNRequest(TestCase):
authorization_flow=create_test_flow(),
acs_url="https://10.120.20.200/saml-sp/SAML2/POST",
audience="https://10.120.20.200/saml-sp/SAML2/POST",
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
signing_kp=static_keypair,
verification_kp=static_keypair,
)
@@ -553,7 +547,7 @@ class TestAuthNRequest(TestCase):
"saml/acs/2d737f96-55fb-4035-953e-5e24134eb778"
),
audience="https://10.120.20.200/saml-sp/SAML2/POST",
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
signing_kp=create_test_cert(),
)
parsed_request = AuthNRequestParser(provider).parse(POST_REQUEST)

View File

@@ -47,7 +47,7 @@ class TestNativeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp1.example.com/acs",
sls_url="https://sp1.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
@@ -58,7 +58,7 @@ class TestNativeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="post",
sls_binding="post",
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
@@ -218,7 +218,7 @@ class TestIframeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp1.example.com/acs",
sls_url="https://sp1.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
logout_method="frontchannel_iframe",
@@ -229,7 +229,7 @@ class TestIframeLogoutStageView(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="post",
sls_binding="post",
logout_method="frontchannel_iframe",
@@ -372,7 +372,7 @@ class TestIdPLogoutIntegration(FlowTestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signing_kp=self.keypair,

View File

@@ -28,7 +28,7 @@ class TestLogoutIntegration(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signature_algorithm=RSA_SHA256,
@@ -57,7 +57,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -72,7 +72,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse_detached(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -106,7 +106,7 @@ class TestLogoutIntegration(TestCase):
parsed = parser.parse(encoded)
# Verify all fields match
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "signed@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "signed-session-456")
@@ -125,7 +125,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse_detached(saml_request)
# Verify parsing succeeded
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
@@ -164,7 +164,7 @@ class TestLogoutIntegration(TestCase):
# Parse the SAMLRequest (unsigned XML)
parsed = self.parser.parse_detached(params["SAMLRequest"][0])
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
def test_form_data_can_be_parsed(self):
"""Test that form data generates parseable POST request"""
@@ -175,7 +175,7 @@ class TestLogoutIntegration(TestCase):
parsed = self.parser.parse(form_data["SAMLRequest"])
# Verify parsing succeeded
self.assertEqual(parsed.issuer, self.provider.issuer_override)
self.assertEqual(parsed.issuer, self.provider.issuer)
self.assertEqual(parsed.name_id, "test@example.com")
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
self.assertEqual(parsed.session_index, "test-session-123")
@@ -244,4 +244,4 @@ class TestLogoutIntegration(TestCase):
# But same issuer
self.assertEqual(parsed1.issuer, parsed2.issuer)
self.assertEqual(parsed1.issuer, self.provider.issuer_override)
self.assertEqual(parsed1.issuer, self.provider.issuer)

View File

@@ -35,7 +35,7 @@ class TestLogoutRequestProcessor(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signature_algorithm=RSA_SHA256,

View File

@@ -1,7 +1,7 @@
"""logout response tests"""
from defusedxml import ElementTree
from django.test import RequestFactory, TestCase
from django.test import TestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.common.saml.constants import (
@@ -9,13 +9,10 @@ from authentik.common.saml.constants import (
NS_SAML_PROTOCOL,
NS_SIGNATURE,
)
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.lib.generators import generate_id
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.logout_request_parser import LogoutRequest
from authentik.providers.saml.processors.logout_response_processor import LogoutResponseProcessor
from authentik.providers.saml.processors.metadata import MetadataProcessor
class TestLogoutResponse(TestCase):
@@ -24,7 +21,6 @@ class TestLogoutResponse(TestCase):
@apply_blueprint("system/providers-saml.yaml")
def setUp(self):
cert = create_test_cert()
self.factory = RequestFactory()
self.provider: SAMLProvider = SAMLProvider.objects.create(
authorization_flow=create_test_flow(),
acs_url="http://testserver/source/saml/provider/acs/",
@@ -34,31 +30,17 @@ class TestLogoutResponse(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
self.application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
def test_build_response(self):
"""Test building a LogoutResponse uses the generated issuer from the assertion"""
# Generate the issuer the same way the assertion/metadata processors would
request = self.factory.get("/")
metadata_processor = MetadataProcessor(self.provider, request)
generated_issuer = metadata_processor._get_issuer_value()
"""Test building a LogoutResponse"""
logout_request = LogoutRequest(
id="test-request-id",
issuer="test-sp",
relay_state="test-relay-state",
)
# Pass the generated issuer as if it came from SAMLSession.issuer
processor = LogoutResponseProcessor(
self.provider,
logout_request,
destination=self.provider.sls_url,
issuer=generated_issuer,
self.provider, logout_request, destination=self.provider.sls_url
)
response_xml = processor.build_response(status="Success")
@@ -69,9 +51,9 @@ class TestLogoutResponse(TestCase):
self.assertEqual(root.attrib["Destination"], self.provider.sls_url)
self.assertEqual(root.attrib["InResponseTo"], "test-request-id")
# Check Issuer matches the generated issuer from the assertion processor
# Check Issuer
issuer = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
self.assertEqual(issuer.text, generated_issuer)
self.assertEqual(issuer.text, self.provider.issuer)
# Check Status
status = root.find(f".//{{{NS_SAML_PROTOCOL}}}StatusCode")

View File

@@ -85,6 +85,7 @@ class TestServiceProviderMetadataParser(TestCase):
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml"))
provider = metadata.to_provider("test", self.flow, self.flow)
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertEqual(provider.default_name_id_policy, SAMLNameIDPolicy.EMAIL)
self.assertEqual(
@@ -98,6 +99,7 @@ class TestServiceProviderMetadataParser(TestCase):
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/cert.xml"))
provider = metadata.to_provider("test", self.flow, self.flow)
self.assertEqual(provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/apps/user_saml/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertEqual(
provider.verification_kp.certificate_data, load_fixture("fixtures/cert.pem")

View File

@@ -32,7 +32,7 @@ class TestSAMLSessionModel(TestCase):
name="test-provider",
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
)
# Create another provider for testing
@@ -40,7 +40,7 @@ class TestSAMLSessionModel(TestCase):
name="test-provider-2",
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
issuer_override="https://idp2.example.com",
issuer="https://idp2.example.com",
)
# Create a session first (using authentik's custom Session model)
@@ -72,7 +72,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify the session was created
@@ -101,7 +100,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Try to create another session with same session_index and provider
@@ -115,7 +113,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
def test_cascade_deletion_user(self):
@@ -130,7 +127,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -154,7 +150,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -178,7 +173,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify session exists
@@ -202,7 +196,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Create second session with different provider
@@ -215,7 +208,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify both sessions exist
@@ -237,7 +229,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=future_time,
expiring=True,
issuer="authentik",
)
# Verify expiry time
@@ -257,7 +248,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=past_time,
expiring=True,
issuer="authentik",
)
# Check if marked as expired
@@ -275,7 +265,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format="", # Blank format
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify it was created successfully
@@ -294,7 +283,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
session2 = SAMLSession.objects.create(
@@ -306,7 +294,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Query by provider
@@ -329,7 +316,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Check serializer property
@@ -348,7 +334,6 @@ class TestSAMLSessionModel(TestCase):
name_id_format=self.name_id_format,
expires=self.expires,
expiring=True,
issuer="authentik",
)
# Verify sessions exist

View File

@@ -7,7 +7,6 @@ from guardian.shortcuts import get_anonymous_user
from lxml import etree # nosec
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import RequestFactory, create_test_cert, create_test_flow
from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
@@ -31,11 +30,6 @@ class TestSchema(TestCase):
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
Application.objects.create(
name="test-app",
slug="test-app",
provider=self.provider,
)
self.source = SAMLSource.objects.create(
slug="provider",
issuer="authentik",

View File

@@ -28,7 +28,7 @@ class TestSendSamlLogoutResponse(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)
@@ -137,7 +137,7 @@ class TestSendSamlLogoutRequest(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)
@@ -155,7 +155,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
self.assertTrue(result)
@@ -180,7 +179,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
self.assertFalse(result)
@@ -200,7 +198,6 @@ class TestSendSamlLogoutRequest(TestCase):
name_id="test@example.com",
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
session_index="test-session-123",
issuer="https://idp.example.com",
)
@@ -217,7 +214,7 @@ class TestSendPostLogoutRequest(TestCase):
authorization_flow=self.flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
signing_kp=self.cert,
)

View File

@@ -40,7 +40,7 @@ class TestSPInitiatedSLOViews(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
)
@@ -90,7 +90,7 @@ class TestSPInitiatedSLOViews(TestCase):
# Verify logout request was stored in plan context
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
self.assertEqual(logout_request.issuer, self.provider.issuer)
self.assertEqual(logout_request.session_index, "test-session-123")
def test_redirect_view_handles_logout_response_with_plan_context(self):
@@ -228,7 +228,7 @@ class TestSPInitiatedSLOViews(TestCase):
# Verify logout request was stored in plan context
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
self.assertEqual(logout_request.issuer, self.provider.issuer)
self.assertEqual(logout_request.session_index, "test-session-123")
def test_post_view_handles_logout_response_with_plan_context(self):
@@ -396,7 +396,7 @@ class TestSPInitiatedSLOViews(TestCase):
authorization_flow=self.flow,
acs_url="https://sp2.example.com/acs",
sls_url="https://sp2.example.com/sls",
issuer_override="https://idp2.example.com",
issuer="https://idp2.example.com",
invalidation_flow=None, # No invalidation flow
)
@@ -524,7 +524,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="https://sp.example.com/sls",
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
sp_binding="redirect",
sls_binding="redirect",
signing_kp=self.cert,
@@ -714,7 +714,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
invalidation_flow=self.invalidation_flow,
acs_url="https://sp.example.com/acs",
sls_url="", # No SLS URL
issuer_override="https://idp.example.com",
issuer="https://idp.example.com",
)
app_no_sls = Application.objects.create(

Some files were not shown because too many files have changed in this diff Show More