mirror of
https://github.com/goauthentik/authentik
synced 2026-05-05 22:52:42 +02:00
Compare commits
96 Commits
core/allow
...
docs/invit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
731163200a | ||
|
|
a8db2882ec | ||
|
|
befc15ad92 | ||
|
|
2b48c27760 | ||
|
|
6be7b2f7b7 | ||
|
|
7cffbb4d07 | ||
|
|
5d629bec9b | ||
|
|
5357f42029 | ||
|
|
716bc6e136 | ||
|
|
60355fdf80 | ||
|
|
828a380569 | ||
|
|
b04f8a6177 | ||
|
|
ff190847f2 | ||
|
|
a7339c7f87 | ||
|
|
38ae472f6c | ||
|
|
7d0656c6fa | ||
|
|
0bbe415b5b | ||
|
|
e52c1b2bdc | ||
|
|
5064167f28 | ||
|
|
bca0f51b53 | ||
|
|
67c197e5a5 | ||
|
|
32b17da699 | ||
|
|
c75eed630a | ||
|
|
9f17d6df96 | ||
|
|
13c8ad5c56 | ||
|
|
28209c03e2 | ||
|
|
f47cf08d8a | ||
|
|
d69433b314 | ||
|
|
849a6053ad | ||
|
|
abdbe0269f | ||
|
|
55384c384a | ||
|
|
06fd68f076 | ||
|
|
d35ab99b2d | ||
|
|
a3b0180049 | ||
|
|
88a545f4fb | ||
|
|
ba62507fc2 | ||
|
|
82fc2e2c80 | ||
|
|
80b3739640 | ||
|
|
1258e1eada | ||
|
|
96ed17e760 | ||
|
|
4b17468b6e | ||
|
|
c834681251 | ||
|
|
9edd7cfbda | ||
|
|
4851179522 | ||
|
|
685f920de2 | ||
|
|
3b4d51b0c5 | ||
|
|
a1098d00b7 | ||
|
|
0d4984b964 | ||
|
|
38330df1f9 | ||
|
|
8b03c36d5a | ||
|
|
07a53a101c | ||
|
|
a3db2ce6a3 | ||
|
|
5487cdb874 | ||
|
|
2d5160d09b | ||
|
|
973fe0bd65 | ||
|
|
58b5e605de | ||
|
|
626e23b87a | ||
|
|
3559beba9c | ||
|
|
0b6d3a2850 | ||
|
|
56ca192391 | ||
|
|
6df62aaa2a | ||
|
|
ca344a64c4 | ||
|
|
a0cdd81f71 | ||
|
|
8eff4c7e0b | ||
|
|
d241a0e8f1 | ||
|
|
ebfc01fcda | ||
|
|
4b0e8a411b | ||
|
|
9bf6595fc6 | ||
|
|
5c07e845d2 | ||
|
|
4f76232e7c | ||
|
|
846f8a7e30 | ||
|
|
fa1c3490c3 | ||
|
|
a35edf7d0f | ||
|
|
9d4d5b7133 | ||
|
|
8d91a76bc9 | ||
|
|
6910428a93 | ||
|
|
cb181d388a | ||
|
|
aad4b6f925 | ||
|
|
821b74d7c1 | ||
|
|
8963d29ab4 | ||
|
|
699360064e | ||
|
|
3f94f830fc | ||
|
|
aaba353a9e | ||
|
|
abdff1c877 | ||
|
|
16fd8183b0 | ||
|
|
d3eaa3a4d9 | ||
|
|
02aba83017 | ||
|
|
e78c43e9d9 | ||
|
|
d6c0ae21de | ||
|
|
2c35df35b6 | ||
|
|
90d4f4296b | ||
|
|
bf7747268b | ||
|
|
552cb78458 | ||
|
|
899994027d | ||
|
|
99250b0498 | ||
|
|
a2ca19d718 |
4
.github/actions/setup/action.yml
vendored
4
.github/actions/setup/action.yml
vendored
@@ -64,7 +64,7 @@ runs:
|
||||
rustflags: ""
|
||||
- name: Setup rust dependencies
|
||||
if: ${{ contains(inputs.dependencies, 'rust') }}
|
||||
uses: taiki-e/install-action@cf525cb33f51aca27cd6fa02034117ab963ff9f1 # v2
|
||||
uses: taiki-e/install-action@b5fddbb5361bce8a06fb168c9d403a6cc552b084 # v2
|
||||
with:
|
||||
tool: cargo-deny cargo-machete cargo-llvm-cov nextest
|
||||
- name: Setup node (web)
|
||||
@@ -104,7 +104,7 @@ runs:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
export PSQL_TAG=${{ inputs.postgresql_version }}
|
||||
docker compose -f .github/actions/setup/compose.yml up -d
|
||||
docker compose -f .github/actions/setup/compose.yml up -d --wait
|
||||
cd web && npm ci
|
||||
- name: Generate config
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
|
||||
6
.github/actions/setup/compose.yml
vendored
6
.github/actions/setup/compose.yml
vendored
@@ -8,8 +8,14 @@ services:
|
||||
POSTGRES_USER: authentik
|
||||
POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"
|
||||
POSTGRES_DB: authentik
|
||||
PGDATA: /var/lib/postgresql/data/pgdata
|
||||
ports:
|
||||
- 5432:5432
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB} -h 127.0.0.1"]
|
||||
interval: 1s
|
||||
timeout: 5s
|
||||
retries: 60
|
||||
restart: always
|
||||
s3:
|
||||
container_name: s3
|
||||
|
||||
2
.github/workflows/_reusable-docker-build.yml
vendored
2
.github/workflows/_reusable-docker-build.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: int128/docker-manifest-create-action@7df7f9e221d927eaadf87db231ddf728047308a4 # v2
|
||||
- uses: int128/docker-manifest-create-action@fa55f72001a6c74b0f4997dca65c70d334905180 # v2
|
||||
id: build
|
||||
with:
|
||||
tags: ${{ matrix.tag }}
|
||||
|
||||
16
.github/workflows/ci-main.yml
vendored
16
.github/workflows/ci-main.yml
vendored
@@ -282,10 +282,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
job:
|
||||
- name: basic
|
||||
glob: tests/openid_conformance/test_basic.py
|
||||
- name: implicit
|
||||
glob: tests/openid_conformance/test_implicit.py
|
||||
- name: oidc_basic
|
||||
glob: tests/openid_conformance/test_oidc_basic.py
|
||||
- name: oidc_implicit
|
||||
glob: tests/openid_conformance/test_oidc_implicit.py
|
||||
- name: oidc_rp-initiated
|
||||
glob: tests/openid_conformance/test_oidc_rp_initiated.py
|
||||
- name: oidc_frontchannel
|
||||
glob: tests/openid_conformance/test_oidc_frontchannel.py
|
||||
- name: oidc_backchannel
|
||||
glob: tests/openid_conformance/test_oidc_backchannel.py
|
||||
- name: ssf_transmitter
|
||||
glob: tests/openid_conformance/test_ssf_transmitter.py
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
- name: Setup authentik env
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -229,6 +229,11 @@ source_docs/
|
||||
|
||||
### Golang ###
|
||||
/vendor/
|
||||
server
|
||||
proxy
|
||||
ldap
|
||||
rac
|
||||
radius
|
||||
|
||||
### Docker ###
|
||||
tests/openid_conformance/exports/*.zip
|
||||
|
||||
112
Cargo.lock
generated
112
Cargo.lock
generated
@@ -17,18 +17,6 @@ version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.4"
|
||||
@@ -203,6 +191,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1014,6 +1003,17 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "evmap"
|
||||
version = "11.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b8874945f036109c72242964c1174cf99434e30cfa45bf45fedc983f50046f8"
|
||||
dependencies = [
|
||||
"hashbag",
|
||||
"left-right",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "eyre"
|
||||
version = "0.6.12"
|
||||
@@ -1230,6 +1230,21 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generator"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"log",
|
||||
"rustversion",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
@@ -1311,6 +1326,12 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbag"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7040a10f52cba493ddb09926e15d10a9d8a28043708a405931fe4c6f19fac064"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
@@ -1868,6 +1889,17 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
|
||||
|
||||
[[package]]
|
||||
name = "left-right"
|
||||
version = "0.11.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f0c21e4c8ff95f487fb34e6f9182875f42c84cef966d29216bf115d9bba835a"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"loom",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.183"
|
||||
@@ -1939,6 +1971,19 @@ version = "0.4.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||
|
||||
[[package]]
|
||||
name = "loom"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"generator",
|
||||
"scoped-tls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
@@ -1978,21 +2023,22 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.24.3"
|
||||
version = "0.24.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8"
|
||||
checksum = "ff56c2e7dce6bd462e3b8919986a617027481b1dcc703175b58cf9dd98a2f071"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"portable-atomic",
|
||||
"rapidhash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-exporter-prometheus"
|
||||
version = "0.18.1"
|
||||
version = "0.18.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3589659543c04c7dc5526ec858591015b87cd8746583b51b48ef4353f99dbcda"
|
||||
checksum = "1db0d8f1fc9e62caebd0319e11eaec5822b0186c171568f0480b46a0137f9108"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"evmap",
|
||||
"indexmap",
|
||||
"metrics",
|
||||
"metrics-util",
|
||||
@@ -2813,6 +2859,15 @@ dependencies = [
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rapidhash"
|
||||
version = "4.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5e48930979c155e2f33aa36ab3119b5ee81332beb6482199a8ecd6029b80b59"
|
||||
dependencies = [
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.6.0"
|
||||
@@ -2871,9 +2926,9 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.13.2"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
|
||||
checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
@@ -3000,9 +3055,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.39"
|
||||
version = "0.23.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
|
||||
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"log",
|
||||
@@ -3105,6 +3160,12 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scoped-tls"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
@@ -4515,6 +4576,15 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "8.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81995fafaaaf6ae47a7d0cc83c67caf92aeb7e5331650ae6ff856f7c0c60c459"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -43,15 +43,15 @@ hyper-unix-socket = "= 0.6.1"
|
||||
hyper-util = "= 0.1.20"
|
||||
ipnet = { version = "= 2.12.0", features = ["serde"] }
|
||||
json-subscriber = "= 0.2.8"
|
||||
metrics = "= 0.24.3"
|
||||
metrics-exporter-prometheus = { version = "= 0.18.1", default-features = false }
|
||||
metrics = "= 0.24.5"
|
||||
metrics-exporter-prometheus = { version = "= 0.18.3", default-features = false }
|
||||
nix = { version = "= 0.31.2", features = ["hostname", "signal"] }
|
||||
notify = "= 8.2.0"
|
||||
pin-project-lite = "= 0.2.17"
|
||||
pyo3 = "= 0.28.3"
|
||||
pyo3-build-config = "= 0.28.3"
|
||||
regex = "= 1.12.3"
|
||||
reqwest = { version = "= 0.13.2", features = [
|
||||
reqwest = { version = "= 0.13.3", features = [
|
||||
"form",
|
||||
"json",
|
||||
"multipart",
|
||||
@@ -66,7 +66,7 @@ reqwest-middleware = { version = "= 0.5.1", features = [
|
||||
"query",
|
||||
"rustls",
|
||||
] }
|
||||
rustls = { version = "= 0.23.39", features = ["fips"] }
|
||||
rustls = { version = "= 0.23.40", features = ["fips"] }
|
||||
sentry = { version = "= 0.47.0", default-features = false, features = [
|
||||
"backtrace",
|
||||
"contexts",
|
||||
@@ -113,6 +113,7 @@ tracing-subscriber = { version = "= 0.3.23", features = [
|
||||
] }
|
||||
url = "= 2.5.8"
|
||||
uuid = { version = "= 1.23.1", features = ["serde", "v4"] }
|
||||
which = "= 8.0.2"
|
||||
|
||||
ak-axum = { package = "authentik-axum", version = "2026.5.0-rc1", path = "./packages/ak-axum" }
|
||||
ak-client = { package = "authentik-client", version = "2026.5.0-rc1", path = "./packages/client-rust" }
|
||||
@@ -282,6 +283,7 @@ sqlx = { workspace = true, optional = true }
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
uuid.workspace = true
|
||||
which.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
11
Makefile
11
Makefile
@@ -109,14 +109,11 @@ i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that requir
|
||||
aws-cfn:
|
||||
cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
|
||||
|
||||
run-server: ## Run the main authentik server process
|
||||
$(UV) run ak server
|
||||
run: ## Run the main authentik server and worker processes
|
||||
$(UV) run ak allinone
|
||||
|
||||
run-worker: ## Run the main authentik worker process
|
||||
$(UV) run ak worker
|
||||
|
||||
run-worker-watch: ## Run the authentik worker, with auto reloading
|
||||
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- $(UV) run ak worker
|
||||
run-watch: ## Run the authentik server and worker, with auto reloading
|
||||
watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs,go --no-meta --notify -- $(UV) run ak allinone
|
||||
|
||||
core-i18n-extract:
|
||||
$(UV) run ak makemessages \
|
||||
|
||||
@@ -1,31 +1,73 @@
|
||||
"""authentik API Modelviewset tests"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.test import TestCase
|
||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||
|
||||
from authentik.admin.api.version_history import VersionHistoryViewSet
|
||||
from authentik.api.v3.urls import router
|
||||
from authentik.core.tests.utils import RequestFactory, create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.tenants.api.domains import DomainViewSet
|
||||
from authentik.tenants.api.tenants import TenantViewSet
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
class TestModelViewSets(TestCase):
|
||||
"""Test Viewset"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable:
|
||||
|
||||
def viewset_tester_factory(test_viewset: type[ModelViewSet], full=True) -> dict[str, Callable]:
|
||||
"""Test Viewset"""
|
||||
|
||||
def tester(self: TestModelViewSets):
|
||||
self.assertIsNotNone(getattr(test_viewset, "search_fields", None))
|
||||
def test_attrs(self: TestModelViewSets) -> None:
|
||||
"""Test attributes we require on all viewsets"""
|
||||
self.assertIsNotNone(getattr(test_viewset, "ordering", None))
|
||||
self.assertIsNotNone(getattr(test_viewset, "search_fields", None))
|
||||
filterset_class = getattr(test_viewset, "filterset_class", None)
|
||||
if not filterset_class:
|
||||
self.assertIsNotNone(getattr(test_viewset, "filterset_fields", None))
|
||||
|
||||
return tester
|
||||
def test_ordering(self: TestModelViewSets) -> None:
|
||||
"""Test that all ordering fields are correct"""
|
||||
view = test_viewset.as_view({"get": "list"})
|
||||
for ordering_field in test_viewset.ordering:
|
||||
with self.subTest(ordering_field):
|
||||
req = self.factory.get(
|
||||
f"/?{urlencode({'ordering': ordering_field}, doseq=True)}", user=self.user
|
||||
)
|
||||
req.tenant = get_current_tenant()
|
||||
res = view(req)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_search(self: TestModelViewSets) -> None:
|
||||
"""Test that search fields are correct"""
|
||||
view = test_viewset.as_view({"get": "list"})
|
||||
req = self.factory.get(
|
||||
f"/?{urlencode({'search': generate_id()}, doseq=True)}", user=self.user
|
||||
)
|
||||
req.tenant = get_current_tenant()
|
||||
res = view(req)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
cases = {
|
||||
"attrs": test_attrs,
|
||||
}
|
||||
if full:
|
||||
cases["ordering"] = test_ordering
|
||||
cases["search"] = test_search
|
||||
return cases
|
||||
|
||||
|
||||
for _, viewset, _ in router.registry:
|
||||
if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet):
|
||||
continue
|
||||
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset))
|
||||
full = viewset not in [VersionHistoryViewSet, DomainViewSet, TenantViewSet]
|
||||
for test, case in viewset_tester_factory(viewset, full=full).items():
|
||||
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}_{test}", case)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Serializer mixin for managed models"""
|
||||
|
||||
from json import JSONDecodeError, loads
|
||||
from typing import cast
|
||||
|
||||
from django.conf import settings
|
||||
@@ -44,6 +45,7 @@ class BlueprintUploadSerializer(PassiveSerializer):
|
||||
|
||||
file = FileField(required=False)
|
||||
path = CharField(required=False)
|
||||
context = CharField(required=False, allow_blank=True)
|
||||
|
||||
def validate_path(self, path: str) -> str:
|
||||
"""Ensure the path (if set) specified is retrievable"""
|
||||
@@ -54,6 +56,18 @@ class BlueprintUploadSerializer(PassiveSerializer):
|
||||
raise ValidationError(_("Blueprint file does not exist"))
|
||||
return path
|
||||
|
||||
def validate_context(self, context: str) -> dict:
|
||||
"""Parse context as a JSON object"""
|
||||
if not context:
|
||||
return {}
|
||||
try:
|
||||
parsed = loads(context)
|
||||
except JSONDecodeError as exc:
|
||||
raise ValidationError(_("Context must be valid JSON")) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValidationError(_("Context must be a JSON object"))
|
||||
return parsed
|
||||
|
||||
|
||||
class ManagedSerializer:
|
||||
"""Managed Serializer"""
|
||||
@@ -126,7 +140,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
|
||||
|
||||
def check_blueprint_perms(blueprint: Blueprint, user: User, explicit_action: str | None = None):
|
||||
"""Check for individual permissions for each model in a blueprint"""
|
||||
for entry in blueprint.entries:
|
||||
for entry in blueprint.iter_entries():
|
||||
full_model = entry.get_model(blueprint)
|
||||
app, __, model = full_model.partition(".")
|
||||
perms = [
|
||||
@@ -224,7 +238,8 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
|
||||
).retrieve_file()
|
||||
else:
|
||||
raise ValidationError("Either path or file must be set")
|
||||
importer = Importer.from_string(string_contents)
|
||||
context = body.validated_data.get("context") or {}
|
||||
importer = Importer.from_string(string_contents, context)
|
||||
|
||||
check_blueprint_perms(importer.blueprint, request.user)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test blueprints v1 api"""
|
||||
|
||||
from json import loads
|
||||
from json import dumps, loads
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
|
||||
from django.urls import reverse
|
||||
@@ -8,7 +8,11 @@ from rest_framework.test import APITestCase
|
||||
from yaml import dump
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.stages.invitation.models import InvitationStage
|
||||
from authentik.stages.user_write.models import UserWriteStage
|
||||
|
||||
TMP = mkdtemp("authentik-blueprints")
|
||||
|
||||
@@ -80,3 +84,107 @@ class TestBlueprintsV1API(APITestCase):
|
||||
res.content.decode(),
|
||||
{"content": ["Failed to validate blueprint", "- Invalid blueprint version"]},
|
||||
)
|
||||
|
||||
def test_api_import_with_context(self):
|
||||
"""Test that the import endpoint applies the supplied context to the real blueprint"""
|
||||
slug = f"invitation-enrollment-{generate_id()}"
|
||||
flow_name = f"Invitation Enrollment {generate_id()}"
|
||||
stage_name = f"invitation-stage-{generate_id()}"
|
||||
user_type = "internal"
|
||||
continue_without_invitation = True
|
||||
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={
|
||||
"path": "example/flows-invitation-enrollment-minimal.yaml",
|
||||
"context": dumps(
|
||||
{
|
||||
"flow_slug": slug,
|
||||
"flow_name": flow_name,
|
||||
"stage_name": stage_name,
|
||||
"continue_flow_without_invitation": continue_without_invitation,
|
||||
"user_type": user_type,
|
||||
}
|
||||
),
|
||||
},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertTrue(res.json()["success"])
|
||||
|
||||
flow = Flow.objects.get(slug=slug)
|
||||
self.assertEqual(flow.name, flow_name)
|
||||
self.assertEqual(flow.title, flow_name)
|
||||
|
||||
invitation_stage = InvitationStage.objects.get(name=stage_name)
|
||||
self.assertEqual(
|
||||
invitation_stage.continue_flow_without_invitation,
|
||||
continue_without_invitation,
|
||||
)
|
||||
|
||||
user_write_stage = UserWriteStage.objects.get(
|
||||
name=f"invitation-enrollment-user-write-{slug}"
|
||||
)
|
||||
self.assertEqual(user_write_stage.user_type, user_type)
|
||||
self.assertEqual(user_write_stage.user_path_template, f"users/{user_type}")
|
||||
|
||||
def test_api_import_blank_path(self):
|
||||
"""Validator returns empty path unchanged (covers api.py:53)."""
|
||||
with NamedTemporaryFile(mode="w+", suffix=".yaml") as file:
|
||||
file.write(dump({"version": 1, "entries": []}))
|
||||
file.flush()
|
||||
file.seek(0)
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={"path": "", "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_api_import_unknown_path(self):
|
||||
"""Path not in available blueprints is rejected (covers api.py:56)."""
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={"path": "does/not/exist.yaml"},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertIn("Blueprint file does not exist", res.content.decode())
|
||||
|
||||
def test_api_import_blank_context(self):
|
||||
"""Blank context is normalized to empty dict (covers api.py:62)."""
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={
|
||||
"path": "example/flows-invitation-enrollment-minimal.yaml",
|
||||
"context": "",
|
||||
},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_api_import_invalid_json_context(self):
|
||||
"""Malformed JSON context raises ValidationError (covers api.py:65-66)."""
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={
|
||||
"path": "example/flows-invitation-enrollment-minimal.yaml",
|
||||
"context": "{not json",
|
||||
},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertIn("Context must be valid JSON", res.content.decode())
|
||||
|
||||
def test_api_import_non_object_context(self):
|
||||
"""JSON context that isn't an object is rejected (covers api.py:68)."""
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={
|
||||
"path": "example/flows-invitation-enrollment-minimal.yaml",
|
||||
"context": "[1, 2, 3]",
|
||||
},
|
||||
format="multipart",
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertIn("Context must be a JSON object", res.content.decode())
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Test blueprints v1"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
@@ -42,3 +45,45 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
|
||||
# Ensure objects do not exist
|
||||
self.assertFalse(Flow.objects.filter(slug=flow_slug1))
|
||||
self.assertFalse(Flow.objects.filter(slug=flow_slug2))
|
||||
|
||||
def test_enterprise_license_context_unlicensed(self):
|
||||
"""Test enterprise license context defaults to a false boolean when unlicensed."""
|
||||
license_key = LicenseKey("test", 0, "Test license", 0, 0)
|
||||
|
||||
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
|
||||
importer = Importer.from_string("""
|
||||
version: 1
|
||||
entries:
|
||||
- identifiers:
|
||||
name: enterprise-test
|
||||
slug: enterprise-test
|
||||
model: authentik_flows.flow
|
||||
conditions:
|
||||
- !Context goauthentik.io/enterprise/licensed
|
||||
attrs:
|
||||
designation: stage_configuration
|
||||
title: foo
|
||||
""")
|
||||
|
||||
self.assertIs(importer.blueprint.context["goauthentik.io/enterprise/licensed"], False)
|
||||
|
||||
def test_enterprise_license_context_licensed(self):
|
||||
"""Test enterprise license context defaults to a true boolean when licensed."""
|
||||
license_key = LicenseKey("test", 253402300799, "Test license", 1000, 1000)
|
||||
|
||||
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
|
||||
importer = Importer.from_string("""
|
||||
version: 1
|
||||
entries:
|
||||
- identifiers:
|
||||
name: enterprise-test
|
||||
slug: enterprise-test
|
||||
model: authentik_flows.flow
|
||||
conditions:
|
||||
- !Context goauthentik.io/enterprise/licensed
|
||||
attrs:
|
||||
designation: stage_configuration
|
||||
title: foo
|
||||
""")
|
||||
|
||||
self.assertIs(importer.blueprint.context["goauthentik.io/enterprise/licensed"], True)
|
||||
|
||||
@@ -146,9 +146,7 @@ class Importer:
|
||||
try:
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
context["goauthentik.io/enterprise/licensed"] = (
|
||||
LicenseKey.get_total().status().is_valid,
|
||||
)
|
||||
context["goauthentik.io/enterprise/licensed"] = LicenseKey.get_total().status().is_valid
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
return context
|
||||
|
||||
@@ -64,6 +64,7 @@ class BrandSerializer(ModelSerializer):
|
||||
"flow_unenrollment",
|
||||
"flow_user_settings",
|
||||
"flow_device_code",
|
||||
"flow_lockdown",
|
||||
"default_application",
|
||||
"web_certificate",
|
||||
"client_certificates",
|
||||
@@ -117,6 +118,7 @@ class CurrentBrandSerializer(PassiveSerializer):
|
||||
flow_unenrollment = CharField(source="flow_unenrollment.slug", required=False)
|
||||
flow_user_settings = CharField(source="flow_user_settings.slug", required=False)
|
||||
flow_device_code = CharField(source="flow_device_code.slug", required=False)
|
||||
flow_lockdown = CharField(source="flow_lockdown.slug", required=False)
|
||||
|
||||
default_locale = CharField(read_only=True)
|
||||
flags = SerializerMethodField()
|
||||
@@ -154,6 +156,7 @@ class BrandViewSet(UsedByMixin, ModelViewSet):
|
||||
"flow_unenrollment",
|
||||
"flow_user_settings",
|
||||
"flow_device_code",
|
||||
"flow_lockdown",
|
||||
"web_certificate",
|
||||
"client_certificates",
|
||||
]
|
||||
|
||||
25
authentik/brands/migrations/0012_brand_flow_lockdown.py
Normal file
25
authentik/brands/migrations/0012_brand_flow_lockdown.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Generated by Django 5.2.12 on 2026-03-14 02:58
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_brands", "0011_alter_brand_branding_default_flow_background_and_more"),
|
||||
("authentik_flows", "0031_alter_flow_layout"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="brand",
|
||||
name="flow_lockdown",
|
||||
field=models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="brand_lockdown",
|
||||
to="authentik_flows.flow",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -58,6 +58,9 @@ class Brand(SerializerModel):
|
||||
flow_device_code = models.ForeignKey(
|
||||
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code"
|
||||
)
|
||||
flow_lockdown = models.ForeignKey(
|
||||
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_lockdown"
|
||||
)
|
||||
|
||||
default_application = models.ForeignKey(
|
||||
"authentik_core.Application",
|
||||
|
||||
@@ -20,11 +20,16 @@ class TestBrands(APITestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.default_flags = {}
|
||||
for flag in Flag.available(visibility="public"):
|
||||
self.default_flags[flag().key] = flag.get()
|
||||
Brand.objects.all().delete()
|
||||
|
||||
@property
|
||||
def default_flags(self) -> dict[str, object]:
|
||||
"""Get current public flags.
|
||||
|
||||
Some tests define temporary Flag subclasses, so this can't be cached in setUp.
|
||||
"""
|
||||
return {flag().key: flag.get() for flag in Flag.available(visibility="public")}
|
||||
|
||||
def test_current_brand(self):
|
||||
"""Test Current brand API"""
|
||||
brand = create_test_brand()
|
||||
|
||||
@@ -30,6 +30,8 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
||||
|
||||
SAML_STATUS_SUCCESS = "urn:oasis:names:tc:SAML:2.0:status:Success"
|
||||
|
||||
DEFAULT_ISSUER = "authentik"
|
||||
|
||||
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
|
||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2
|
||||
|
||||
@@ -47,7 +47,8 @@ class ApplicationEntitlementViewSet(UsedByMixin, ModelViewSet):
|
||||
search_fields = [
|
||||
"pbm_uuid",
|
||||
"name",
|
||||
"app",
|
||||
"app__name",
|
||||
"app__slug",
|
||||
"attributes",
|
||||
]
|
||||
filterset_fields = [
|
||||
|
||||
@@ -32,19 +32,19 @@ from authentik.rbac.decorators import permission_required
|
||||
class UserAgentDeviceDict(TypedDict):
|
||||
"""User agent device"""
|
||||
|
||||
brand: str
|
||||
brand: str | None = None
|
||||
family: str
|
||||
model: str
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class UserAgentOSDict(TypedDict):
|
||||
"""User agent os"""
|
||||
|
||||
family: str
|
||||
major: str
|
||||
minor: str
|
||||
patch: str
|
||||
patch_minor: str
|
||||
major: str | None = None
|
||||
minor: str | None = None
|
||||
patch: str | None = None
|
||||
patch_minor: str | None = None
|
||||
|
||||
|
||||
class UserAgentBrowserDict(TypedDict):
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.utils.http import urlencode
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from django.utils.translation import gettext_lazy
|
||||
from django_filters.filters import (
|
||||
BooleanFilter,
|
||||
CharFilter,
|
||||
@@ -106,6 +107,10 @@ from authentik.stages.email.utils import TemplateEmailMessage
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
INVALID_PASSWORD_HASH_MESSAGE = gettext_lazy(
|
||||
"Invalid password hash format. Must be a valid Django password hash."
|
||||
)
|
||||
|
||||
|
||||
class ParamUserSerializer(PassiveSerializer):
|
||||
"""Partial serializer for query parameters to select a user"""
|
||||
@@ -190,47 +195,79 @@ class UserSerializer(ModelSerializer):
|
||||
return RoleSerializer(instance.roles, many=True).data
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Setting password and permissions directly is allowed only in blueprints."""
|
||||
super().__init__(*args, **kwargs)
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["password"] = CharField(required=False, allow_null=True)
|
||||
self.fields["password_hash"] = CharField(required=False, allow_null=True)
|
||||
self.fields["permissions"] = ListField(
|
||||
required=False,
|
||||
child=ChoiceField(choices=get_permission_choices()),
|
||||
)
|
||||
|
||||
def create(self, validated_data: dict) -> User:
|
||||
"""If this serializer is used in the blueprint context, we allow for
|
||||
directly setting a password. However should be done via the `set_password`
|
||||
method instead of directly setting it like rest_framework."""
|
||||
password = validated_data.pop("password", None)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
|
||||
"""Create a user, with blueprint-only password and permission writes."""
|
||||
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
|
||||
if is_blueprint:
|
||||
password = validated_data.pop("password", None)
|
||||
password_hash = validated_data.pop("password_hash", None)
|
||||
permissions = validated_data.pop("permissions", [])
|
||||
self._validate_password_inputs(password, password_hash)
|
||||
|
||||
instance: User = super().create(validated_data)
|
||||
self._set_password(instance, password)
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
if is_blueprint:
|
||||
self._set_password(instance, password, password_hash)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[permission.split(".")[1] for permission in permissions]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
self._ensure_password_not_empty(instance)
|
||||
return instance
|
||||
|
||||
def update(self, instance: User, validated_data: dict) -> User:
|
||||
"""Same as `create` above, set the password directly if we're in a blueprint
|
||||
context"""
|
||||
password = validated_data.pop("password", None)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in list(perms_qs)]
|
||||
"""Update a user, with blueprint-only password and permission writes."""
|
||||
is_blueprint = SERIALIZER_CONTEXT_BLUEPRINT in self.context
|
||||
if is_blueprint:
|
||||
password = validated_data.pop("password", None)
|
||||
password_hash = validated_data.pop("password_hash", None)
|
||||
permissions = validated_data.pop("permissions", [])
|
||||
self._validate_password_inputs(password, password_hash)
|
||||
|
||||
instance = super().update(instance, validated_data)
|
||||
self._set_password(instance, password)
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
if is_blueprint:
|
||||
self._set_password(instance, password, password_hash)
|
||||
perms_qs = Permission.objects.filter(
|
||||
codename__in=[permission.split(".")[1] for permission in permissions]
|
||||
).values_list("content_type__app_label", "codename")
|
||||
perms_list = [f"{ct}.{name}" for ct, name in perms_qs]
|
||||
instance.assign_perms_to_managed_role(perms_list)
|
||||
self._ensure_password_not_empty(instance)
|
||||
return instance
|
||||
|
||||
def _set_password(self, instance: User, password: str | None):
|
||||
"""Set password of user if we're in a blueprint context, and if it's an empty
|
||||
string then use an unusable password"""
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
|
||||
def _validate_password_inputs(self, password: str | None, password_hash: str | None):
|
||||
"""Validate mutually-exclusive password inputs before any model mutation."""
|
||||
if password is not None and password_hash is not None:
|
||||
raise ValidationError(_("Cannot set both password and password_hash. Use only one."))
|
||||
if password_hash is None:
|
||||
return
|
||||
try:
|
||||
User.validate_password_hash(password_hash)
|
||||
except ValueError as exc:
|
||||
LOGGER.warning("Failed to identify password hash format", exc_info=exc)
|
||||
raise ValidationError(INVALID_PASSWORD_HASH_MESSAGE) from exc
|
||||
|
||||
def _set_password(self, instance: User, password: str | None, password_hash: str | None = None):
|
||||
"""Set password from plain text or hash."""
|
||||
if password_hash is not None:
|
||||
instance.set_password_from_hash(password_hash)
|
||||
instance.save()
|
||||
elif password:
|
||||
instance.set_password(password)
|
||||
instance.save()
|
||||
|
||||
def _ensure_password_not_empty(self, instance: User):
|
||||
"""Store an explicit unusable password instead of an empty password field."""
|
||||
if len(instance.password) == 0:
|
||||
instance.set_unusable_password()
|
||||
instance.save()
|
||||
@@ -399,6 +436,12 @@ class UserPasswordSetSerializer(PassiveSerializer):
|
||||
password = CharField(required=True)
|
||||
|
||||
|
||||
class UserPasswordHashSetSerializer(PassiveSerializer):
|
||||
"""Payload to set a users' password hash directly"""
|
||||
|
||||
password = CharField(required=True)
|
||||
|
||||
|
||||
class UserServiceAccountSerializer(PassiveSerializer):
|
||||
"""Payload to create a service account"""
|
||||
|
||||
@@ -520,6 +563,9 @@ class UsersFilter(FilterSet):
|
||||
|
||||
|
||||
class UserViewSet(
|
||||
ConditionalInheritance(
|
||||
"authentik.enterprise.stages.account_lockdown.api.UserAccountLockdownMixin"
|
||||
),
|
||||
ConditionalInheritance("authentik.enterprise.reports.api.reports.ExportMixin"),
|
||||
UsedByMixin,
|
||||
ModelViewSet,
|
||||
@@ -742,6 +788,11 @@ class UserViewSet(
|
||||
self.request.session.modified = True
|
||||
return Response(serializer.initial_data)
|
||||
|
||||
def _update_session_hash_after_password_change(self, request: Request, user: User):
|
||||
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
|
||||
LOGGER.debug("Updating session hash after password change")
|
||||
update_session_auth_hash(self.request, user)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
@extend_schema(
|
||||
request=UserPasswordSetSerializer,
|
||||
@@ -765,9 +816,45 @@ class UserViewSet(
|
||||
except (ValidationError, IntegrityError) as exc:
|
||||
LOGGER.debug("Failed to set password", exc=exc)
|
||||
return Response(status=400)
|
||||
if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
|
||||
LOGGER.debug("Updating session hash after password change")
|
||||
update_session_auth_hash(self.request, user)
|
||||
self._update_session_hash_after_password_change(request, user)
|
||||
return Response(status=204)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
@extend_schema(
|
||||
request=UserPasswordHashSetSerializer,
|
||||
responses={
|
||||
204: OpenApiResponse(description="Successfully changed password"),
|
||||
400: OpenApiResponse(description="Bad request"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["POST"],
|
||||
permission_classes=[IsAuthenticated],
|
||||
)
|
||||
@validate(UserPasswordHashSetSerializer)
|
||||
def set_password_hash(
|
||||
self, request: Request, pk: int, body: UserPasswordHashSetSerializer
|
||||
) -> Response:
|
||||
"""Set a user's password from a pre-hashed Django password value.
|
||||
|
||||
Submit the Django password hash in the shared ``password`` request field.
|
||||
|
||||
This updates authentik's local password verifier only. It does not attempt
|
||||
to propagate the password change to LDAP or Kerberos because no raw password
|
||||
is available from the request payload.
|
||||
"""
|
||||
user: User = self.get_object()
|
||||
try:
|
||||
user.set_password_from_hash(body.validated_data["password"], request=request)
|
||||
user.save()
|
||||
except ValueError as exc:
|
||||
LOGGER.debug("Failed to set password hash", exc=exc)
|
||||
return Response(data={"password": [INVALID_PASSWORD_HASH_MESSAGE]}, status=400)
|
||||
except (ValidationError, IntegrityError) as exc:
|
||||
LOGGER.debug("Failed to set password hash", exc=exc)
|
||||
return Response(status=400)
|
||||
self._update_session_hash_after_password_change(request, user)
|
||||
return Response(status=204)
|
||||
|
||||
@permission_required("authentik_core.reset_user_password")
|
||||
|
||||
28
authentik/core/management/commands/hash_password.py
Normal file
28
authentik/core/management/commands/hash_password.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Hash password using Django's password hashers"""
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Hash a password using Django's password hashers"""
|
||||
|
||||
help = "Hash a password for use with AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"password",
|
||||
type=str,
|
||||
help="Password to hash",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
password = options["password"]
|
||||
|
||||
if not password:
|
||||
raise CommandError("Password cannot be empty")
|
||||
try:
|
||||
hashed = make_password(password)
|
||||
self.stdout.write(hashed)
|
||||
except ValueError as exc:
|
||||
raise CommandError(f"Error hashing password: {exc}") from exc
|
||||
@@ -10,7 +10,7 @@ from uuid import uuid4
|
||||
|
||||
import pgtrigger
|
||||
from deepmerge import always_merger
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.contrib.auth.hashers import check_password, identify_hasher
|
||||
from django.contrib.auth.models import AbstractUser, Permission
|
||||
from django.contrib.auth.models import UserManager as DjangoUserManager
|
||||
from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
@@ -560,6 +560,33 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
|
||||
self.password_change_date = now()
|
||||
return super().set_password(raw_password)
|
||||
|
||||
@staticmethod
|
||||
def validate_password_hash(password_hash: str):
|
||||
"""Validate that the value is a recognized Django password hash."""
|
||||
identify_hasher(password_hash) # Raises ValueError if invalid
|
||||
|
||||
def set_password_from_hash(self, password_hash: str, signal=True, sender=None, request=None):
|
||||
"""Set password directly from a pre-hashed value.
|
||||
|
||||
Unlike set_password(), this does not hash the input again. The provided value
|
||||
must already be a valid Django password hash, and it is stored directly on the
|
||||
user after validation.
|
||||
|
||||
Because no raw password is available, downstream password sync integrations
|
||||
such as LDAP and Kerberos cannot be updated from this code path.
|
||||
|
||||
Raises ValueError if the hash format is not recognized.
|
||||
"""
|
||||
self.validate_password_hash(password_hash)
|
||||
if self.pk and signal:
|
||||
from authentik.core.signals import password_hash_changed
|
||||
|
||||
if not sender:
|
||||
sender = self
|
||||
password_hash_changed.send(sender=sender, user=self, request=request)
|
||||
self.password = password_hash
|
||||
self.password_change_date = now()
|
||||
|
||||
def check_password(self, raw_password: str) -> bool:
|
||||
"""
|
||||
Return a boolean of whether the raw_password was correct. Handles
|
||||
|
||||
@@ -16,7 +16,11 @@ LOGGER = get_logger()
|
||||
|
||||
@receiver(post_startup)
|
||||
def post_startup_setup_bootstrap(sender, **_):
|
||||
if not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD") and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN"):
|
||||
if (
|
||||
not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD")
|
||||
and not getenv("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH")
|
||||
and not getenv("AUTHENTIK_BOOTSTRAP_TOKEN")
|
||||
):
|
||||
return
|
||||
LOGGER.info("Configuring authentik through bootstrap environment variables")
|
||||
content = BlueprintInstance(path=BOOTSTRAP_BLUEPRINT).retrieve()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""authentik core signals"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.contrib.auth.signals import user_logged_in
|
||||
from django.core.cache import cache
|
||||
@@ -24,6 +23,8 @@ from authentik.root.ws.consumer import build_device_group
|
||||
|
||||
# Arguments: user: User, password: str
|
||||
password_changed = Signal()
|
||||
# Arguments: user: User, request: HttpRequest | None
|
||||
password_hash_changed = Signal()
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest,
|
||||
# stage: Stage, context: dict[str, any]
|
||||
login_failed = Signal()
|
||||
@@ -57,7 +58,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
|
||||
layer = get_channel_layer()
|
||||
device_cookie = request.COOKIES.get("authentik_device")
|
||||
if device_cookie:
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
build_device_group(device_cookie),
|
||||
{"type": "event.session.authenticated"},
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
{% block head %}
|
||||
<style data-id="static-styles">
|
||||
:root {
|
||||
--ak-global--background-image: url("{{ request.brand.branding_default_flow_background_url }}");
|
||||
--ak-global--background-image: url("{{ request.brand.branding_default_flow_background_url|iriencode|safe }}");
|
||||
}
|
||||
</style>
|
||||
|
||||
|
||||
28
authentik/core/tests/test_hash_password_command.py
Normal file
28
authentik/core/tests/test_hash_password_command.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Tests for hash_password management command."""
|
||||
|
||||
from io import StringIO
|
||||
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from django.test import TestCase
|
||||
|
||||
|
||||
class TestHashPasswordCommand(TestCase):
|
||||
"""Test hash_password management command."""
|
||||
|
||||
def test_hash_password(self):
|
||||
"""Test hashing a password."""
|
||||
out = StringIO()
|
||||
call_command("hash_password", "test123", stdout=out)
|
||||
hashed = out.getvalue().strip()
|
||||
|
||||
self.assertTrue(hashed.startswith("pbkdf2_sha256$"))
|
||||
self.assertTrue(check_password("test123", hashed))
|
||||
|
||||
def test_hash_password_empty_fails(self):
|
||||
"""Test that empty password raises error."""
|
||||
with self.assertRaises(CommandError) as ctx:
|
||||
call_command("hash_password", "")
|
||||
|
||||
self.assertIn("Password cannot be empty", str(ctx.exception))
|
||||
@@ -1,6 +1,7 @@
|
||||
from http import HTTPStatus
|
||||
from os import environ
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@@ -16,6 +17,7 @@ from authentik.tenants.flags import patch_flag
|
||||
class TestSetup(FlowTestCase):
|
||||
def tearDown(self):
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD", None)
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_PASSWORD_HASH", None)
|
||||
environ.pop("AUTHENTIK_BOOTSTRAP_TOKEN", None)
|
||||
|
||||
@patch_flag(Setup, True)
|
||||
@@ -154,3 +156,19 @@ class TestSetup(FlowTestCase):
|
||||
token = Token.objects.filter(identifier="authentik-bootstrap-token").first()
|
||||
self.assertEqual(token.intent, TokenIntents.INTENT_API)
|
||||
self.assertEqual(token.key, environ["AUTHENTIK_BOOTSTRAP_TOKEN"])
|
||||
|
||||
def test_setup_bootstrap_env_password_hash(self):
|
||||
"""Test setup with password hash env var"""
|
||||
User.objects.filter(username="akadmin").delete()
|
||||
Setup.set(False)
|
||||
|
||||
password = generate_id()
|
||||
password_hash = make_password(password)
|
||||
environ["AUTHENTIK_BOOTSTRAP_PASSWORD_HASH"] = password_hash
|
||||
pre_startup.send(sender=self)
|
||||
post_startup.send(sender=self)
|
||||
|
||||
self.assertTrue(Setup.get())
|
||||
user = User.objects.get(username="akadmin")
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
"""user tests"""
|
||||
|
||||
from django.test.testcases import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.test.testcases import TestCase
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.users import UserSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.core.signals import password_changed, password_hash_changed
|
||||
from authentik.events.models import Event
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
@@ -33,3 +40,99 @@ class TestUsers(TestCase):
|
||||
self.assertEqual(Event.objects.count(), 1)
|
||||
user.ak_groups.all()
|
||||
self.assertEqual(Event.objects.count(), 1)
|
||||
|
||||
def test_set_password_from_hash_signal_skips_source_sync_receivers(self):
|
||||
"""Test hash password updates do not expose a raw password to sync receivers."""
|
||||
user = User.objects.create(
|
||||
username=generate_id(),
|
||||
attributes={"distinguishedName": "cn=test,ou=users,dc=example,dc=com"},
|
||||
)
|
||||
password_changed_captured = []
|
||||
password_hash_changed_captured = []
|
||||
dispatch_uid = generate_id()
|
||||
hash_dispatch_uid = generate_id()
|
||||
|
||||
def password_changed_receiver(sender, **kwargs):
|
||||
password_changed_captured.append(kwargs)
|
||||
|
||||
def password_hash_changed_receiver(sender, **kwargs):
|
||||
password_hash_changed_captured.append(kwargs)
|
||||
|
||||
password_changed.connect(password_changed_receiver, dispatch_uid=dispatch_uid)
|
||||
password_hash_changed.connect(
|
||||
password_hash_changed_receiver, dispatch_uid=hash_dispatch_uid
|
||||
)
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"authentik.sources.ldap.signals.LDAPSource.objects.filter"
|
||||
) as ldap_sources_filter,
|
||||
patch(
|
||||
"authentik.sources.kerberos.signals."
|
||||
"UserKerberosSourceConnection.objects.select_related"
|
||||
) as kerberos_connections_select,
|
||||
):
|
||||
user.set_password_from_hash(make_password("new-password")) # nosec
|
||||
user.save()
|
||||
finally:
|
||||
password_changed.disconnect(dispatch_uid=dispatch_uid)
|
||||
password_hash_changed.disconnect(dispatch_uid=hash_dispatch_uid)
|
||||
|
||||
self.assertEqual(password_changed_captured, [])
|
||||
self.assertEqual(len(password_hash_changed_captured), 1)
|
||||
ldap_sources_filter.assert_not_called()
|
||||
kerberos_connections_select.assert_not_called()
|
||||
|
||||
|
||||
class TestUserSerializerPasswordHash(TestCase):
|
||||
"""Test UserSerializer password_hash support in blueprint context."""
|
||||
|
||||
def test_password_hash_sets_password_directly(self):
|
||||
"""Test a valid password hash is stored without re-hashing."""
|
||||
password = "test-password-123" # nosec
|
||||
password_hash = make_password(password)
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": password_hash,
|
||||
},
|
||||
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
user = serializer.save()
|
||||
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
self.assertIsNotNone(user.password_change_date)
|
||||
|
||||
def test_password_hash_rejects_invalid_format(self):
|
||||
"""Test invalid password hash values are rejected."""
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": "not-a-valid-hash",
|
||||
},
|
||||
context={SERIALIZER_CONTEXT_BLUEPRINT: True},
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
serializer.save()
|
||||
|
||||
self.assertIn("Invalid password hash format", str(ctx.exception))
|
||||
|
||||
def test_password_hash_ignored_outside_blueprint_context(self):
|
||||
"""Test password_hash is not accepted by the regular serializer."""
|
||||
serializer = UserSerializer(
|
||||
data={
|
||||
"username": generate_id(),
|
||||
"name": "Test User",
|
||||
"password_hash": make_password("test"), # nosec
|
||||
}
|
||||
)
|
||||
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
self.assertNotIn("password_hash", serializer.validated_data)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from datetime import datetime, timedelta
|
||||
from json import loads
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls.base import reverse
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.test import APITestCase
|
||||
@@ -26,6 +27,9 @@ from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignatio
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.stages.email.models import EmailStage
|
||||
|
||||
INVALID_PASSWORD_HASH = "not-a-valid-hash"
|
||||
INVALID_PASSWORD_HASH_ERROR = "Invalid password hash format. Must be a valid Django password hash."
|
||||
|
||||
|
||||
class TestUsersAPI(APITestCase):
|
||||
"""Test Users API"""
|
||||
@@ -34,6 +38,20 @@ class TestUsersAPI(APITestCase):
|
||||
self.admin = create_test_admin_user()
|
||||
self.user = create_test_user()
|
||||
|
||||
def _set_password_hash(self, user: User, password_hash: str, client=None):
|
||||
return (client or self.client).post(
|
||||
reverse("authentik_api:user-set-password-hash", kwargs={"pk": user.pk}),
|
||||
data={"password": password_hash},
|
||||
)
|
||||
|
||||
def _assert_password_hash_set(
|
||||
self, user: User, password: str, password_hash: str, response
|
||||
) -> None:
|
||||
self.assertEqual(response.status_code, 204, response.data)
|
||||
user.refresh_from_db()
|
||||
self.assertEqual(user.password, password_hash)
|
||||
self.assertTrue(user.check_password(password))
|
||||
|
||||
def test_filter_type(self):
|
||||
"""Test API filtering by type"""
|
||||
self.client.force_login(self.admin)
|
||||
@@ -113,6 +131,26 @@ class TestUsersAPI(APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(response.content, {"password": ["This field may not be blank."]})
|
||||
|
||||
def test_set_password_hash(self):
|
||||
"""Test setting a user's password from a hash."""
|
||||
self.client.force_login(self.admin)
|
||||
password = generate_key()
|
||||
password_hash = make_password(password)
|
||||
response = self._set_password_hash(self.user, password_hash)
|
||||
|
||||
self._assert_password_hash_set(self.user, password, password_hash, response)
|
||||
|
||||
def test_set_password_hash_invalid(self):
|
||||
"""Test invalid password hashes are rejected."""
|
||||
self.client.force_login(self.admin)
|
||||
response = self._set_password_hash(self.user, INVALID_PASSWORD_HASH)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content,
|
||||
{"password": [INVALID_PASSWORD_HASH_ERROR]},
|
||||
)
|
||||
|
||||
def test_recovery(self):
|
||||
"""Test user recovery link"""
|
||||
flow = create_test_flow(
|
||||
@@ -261,6 +299,29 @@ class TestUsersAPI(APITestCase):
|
||||
self.assertTrue(token_filter.exists())
|
||||
self.assertTrue(token_filter.first().expiring)
|
||||
|
||||
def test_service_account_set_password_hash(self):
|
||||
"""Service account password hash can be set through the API."""
|
||||
self.client.force_login(self.admin)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-service-account"),
|
||||
data={
|
||||
"name": "test-sa",
|
||||
"create_group": False,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.data)
|
||||
body = loads(response.content)
|
||||
|
||||
user = User.objects.get(pk=body["user_pk"])
|
||||
self.assertEqual(user.type, UserTypes.SERVICE_ACCOUNT)
|
||||
self.assertFalse(user.has_usable_password())
|
||||
|
||||
password = generate_key()
|
||||
password_hash = make_password(password)
|
||||
response = self._set_password_hash(user, password_hash)
|
||||
|
||||
self._assert_password_hash_set(user, password, password_hash, response)
|
||||
|
||||
def test_service_account_no_expire(self):
|
||||
"""Service account creation without token expiration"""
|
||||
self.client.force_login(self.admin)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.db.models import BooleanField as ModelBooleanField
|
||||
from django.db.models import Case, Q, Value, When
|
||||
from django.db.models import Exists, OuterRef, Q, Subquery
|
||||
from django_filters.rest_framework import BooleanFilter, FilterSet
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
@@ -14,7 +13,7 @@ from rest_framework.viewsets import GenericViewSet
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.lifecycle.api.reviews import ReviewSerializer
|
||||
from authentik.enterprise.lifecycle.models import LifecycleIteration, ReviewState
|
||||
from authentik.enterprise.lifecycle.models import LifecycleIteration, LifecycleRule, ReviewState
|
||||
from authentik.enterprise.lifecycle.utils import (
|
||||
ContentTypeField,
|
||||
ReviewerGroupSerializer,
|
||||
@@ -26,20 +25,25 @@ from authentik.enterprise.lifecycle.utils import (
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
|
||||
|
||||
class RelatedRuleSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
reviewer_groups = ReviewerGroupSerializer(many=True, read_only=True)
|
||||
min_reviewers = IntegerField(read_only=True)
|
||||
reviewers = ReviewerUserSerializer(many=True, read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = LifecycleRule
|
||||
fields = ["id", "name", "reviewer_groups", "min_reviewers", "reviewers"]
|
||||
|
||||
|
||||
class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
content_type = ContentTypeField()
|
||||
object_verbose = SerializerMethodField()
|
||||
rule = RelatedRuleSerializer(read_only=True)
|
||||
object_admin_url = SerializerMethodField(read_only=True)
|
||||
grace_period_end = SerializerMethodField(read_only=True)
|
||||
reviews = ReviewSerializer(many=True, read_only=True, source="review_set.all")
|
||||
user_can_review = SerializerMethodField(read_only=True)
|
||||
|
||||
reviewer_groups = ReviewerGroupSerializer(
|
||||
many=True, read_only=True, source="rule.reviewer_groups"
|
||||
)
|
||||
min_reviewers = IntegerField(read_only=True, source="rule.min_reviewers")
|
||||
reviewers = ReviewerUserSerializer(many=True, read_only=True, source="rule.reviewers")
|
||||
|
||||
next_review_date = SerializerMethodField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
@@ -55,10 +59,8 @@ class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
"grace_period_end",
|
||||
"next_review_date",
|
||||
"reviews",
|
||||
"rule",
|
||||
"user_can_review",
|
||||
"reviewer_groups",
|
||||
"min_reviewers",
|
||||
"reviewers",
|
||||
]
|
||||
read_only_fields = fields
|
||||
|
||||
@@ -88,43 +90,55 @@ class IterationViewSet(EnterpriseRequiredMixin, CreateModelMixin, GenericViewSet
|
||||
queryset = LifecycleIteration.objects.all()
|
||||
serializer_class = LifecycleIterationSerializer
|
||||
ordering = ["-opened_on"]
|
||||
ordering_fields = ["state", "content_type__model", "opened_on", "grace_period_end"]
|
||||
ordering_fields = [
|
||||
"state",
|
||||
"content_type__model",
|
||||
"rule__name",
|
||||
"opened_on",
|
||||
"grace_period_end",
|
||||
]
|
||||
filterset_class = LifecycleIterationFilterSet
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
return self.queryset.annotate(
|
||||
user_is_reviewer=Case(
|
||||
When(
|
||||
Q(rule__reviewers=user)
|
||||
| Q(rule__reviewer_groups__in=user.groups.all().with_ancestors()),
|
||||
then=Value(True),
|
||||
),
|
||||
default=Value(False),
|
||||
output_field=ModelBooleanField(),
|
||||
user_is_reviewer=Exists(
|
||||
LifecycleRule.objects.filter(
|
||||
pk=OuterRef("rule_id"),
|
||||
).filter(
|
||||
Q(reviewers=user) | Q(reviewer_groups__in=user.groups.all().with_ancestors())
|
||||
)
|
||||
)
|
||||
).distinct()
|
||||
)
|
||||
|
||||
@extend_schema(
|
||||
operation_id="lifecycle_iterations_list_latest",
|
||||
responses={200: LifecycleIterationSerializer(many=True)},
|
||||
)
|
||||
@action(
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
methods=["get"],
|
||||
url_path=r"latest/(?P<content_type>[^/]+)/(?P<object_id>[^/]+)",
|
||||
)
|
||||
def latest_iteration(self, request: Request, content_type: str, object_id: str) -> Response:
|
||||
def latest_iterations(self, request: Request, content_type: str, object_id: str) -> Response:
|
||||
ct = parse_content_type(content_type)
|
||||
try:
|
||||
obj = (
|
||||
self.get_queryset()
|
||||
.filter(
|
||||
content_type__app_label=ct["app_label"],
|
||||
content_type__model=ct["model"],
|
||||
object_id=object_id,
|
||||
)
|
||||
.latest("opened_on")
|
||||
latest_ids_subquery = (
|
||||
LifecycleIteration.objects.filter(
|
||||
rule=OuterRef("rule"),
|
||||
content_type__app_label=ct["app_label"],
|
||||
content_type__model=ct["model"],
|
||||
object_id=object_id,
|
||||
)
|
||||
except LifecycleIteration.DoesNotExist:
|
||||
return Response(status=404)
|
||||
serializer = self.get_serializer(obj)
|
||||
.order_by("-opened_on")
|
||||
.values("id")[:1]
|
||||
)
|
||||
latest_per_rule = LifecycleIteration.objects.filter(
|
||||
content_type__app_label=ct["app_label"],
|
||||
content_type__model=ct["model"],
|
||||
object_id=object_id,
|
||||
).filter(id=Subquery(latest_ids_subquery))
|
||||
serializer = self.get_serializer(latest_per_rule, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
@extend_schema(
|
||||
|
||||
@@ -84,23 +84,6 @@ class LifecycleRuleSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
raise ValidationError(
|
||||
{"grace_period": _("Grace period must be shorter than the interval.")}
|
||||
)
|
||||
if "content_type" in attrs or "object_id" in attrs:
|
||||
content_type = attrs.get("content_type", getattr(self.instance, "content_type", None))
|
||||
object_id = attrs.get("object_id", getattr(self.instance, "object_id", None))
|
||||
if content_type is not None and object_id is None:
|
||||
existing = LifecycleRule.objects.filter(
|
||||
content_type=content_type, object_id__isnull=True
|
||||
)
|
||||
if self.instance:
|
||||
existing = existing.exclude(pk=self.instance.pk)
|
||||
if existing.exists():
|
||||
raise ValidationError(
|
||||
{
|
||||
"content_type": _(
|
||||
"Only one type-wide rule for each object type is allowed."
|
||||
)
|
||||
}
|
||||
)
|
||||
return attrs
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# Generated by Django 5.2.11 on 2026-03-05 11:27
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_lifecycle", "0002_alter_lifecycleiteration_opened_on"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveConstraint(
|
||||
model_name="lifecyclerule",
|
||||
name="uniq_lifecycle_rule_ct_null_object",
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="lifecyclerule",
|
||||
unique_together=set(),
|
||||
),
|
||||
]
|
||||
@@ -56,14 +56,6 @@ class LifecycleRule(SerializerModel):
|
||||
|
||||
class Meta:
|
||||
indexes = [models.Index(fields=["content_type"])]
|
||||
unique_together = [["content_type", "object_id"]]
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["content_type"],
|
||||
condition=Q(object_id__isnull=True),
|
||||
name="uniq_lifecycle_rule_ct_null_object",
|
||||
)
|
||||
]
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
@@ -82,12 +74,6 @@ class LifecycleRule(SerializerModel):
|
||||
qs = self.content_type.get_all_objects_for_this_type()
|
||||
if self.object_id:
|
||||
qs = qs.filter(pk=self.object_id)
|
||||
else:
|
||||
qs = qs.exclude(
|
||||
pk__in=LifecycleRule.objects.filter(
|
||||
content_type=self.content_type, object_id__isnull=False
|
||||
).values_list(Cast("object_id", output_field=self._get_pk_field()), flat=True)
|
||||
)
|
||||
return qs
|
||||
|
||||
def _get_stale_iterations(self) -> QuerySet[LifecycleIteration]:
|
||||
@@ -107,8 +93,7 @@ class LifecycleRule(SerializerModel):
|
||||
|
||||
def _get_newly_due_objects(self) -> QuerySet:
|
||||
recent_iteration_ids = LifecycleIteration.objects.filter(
|
||||
content_type=self.content_type,
|
||||
object_id__isnull=False,
|
||||
rule=self,
|
||||
opened_on__gte=start_of_day(
|
||||
timezone.now() + timedelta(days=1) - timedelta_from_string(self.interval)
|
||||
),
|
||||
@@ -214,9 +199,15 @@ class LifecycleIteration(SerializerModel, ManagedModel):
|
||||
}
|
||||
|
||||
def initialize(self):
|
||||
if (self.content_type.app_label, self.content_type.model) == ("authentik_core", "group"):
|
||||
object_label = self.object.name
|
||||
elif (self.content_type.app_label, self.content_type.model) == ("authentik_rbac", "role"):
|
||||
object_label = self.object.name
|
||||
else:
|
||||
object_label = str(self.object)
|
||||
event = Event.new(
|
||||
EventAction.REVIEW_INITIATED,
|
||||
message=_(f"Access review is due for {self.content_type.name} {str(self.object)}"),
|
||||
message=_(f"Access review is due for {self.content_type.name.lower()} {object_label}"),
|
||||
**self._get_event_args(),
|
||||
)
|
||||
event.save()
|
||||
|
||||
@@ -3,6 +3,7 @@ from django.db.models.signals import post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.enterprise.lifecycle.models import LifecycleRule, ReviewState
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
|
||||
@receiver(post_save, sender=LifecycleRule)
|
||||
@@ -11,7 +12,9 @@ def post_rule_save(sender, instance: LifecycleRule, created: bool, **_):
|
||||
|
||||
apply_lifecycle_rule.send_with_options(
|
||||
args=(instance.id,),
|
||||
rel_obj=instance,
|
||||
rel_obj=Schedule.objects.get(
|
||||
actor_name="authentik.enterprise.lifecycle.tasks.apply_lifecycle_rules"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,14 +4,17 @@ from dramatiq import actor
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.lifecycle.models import LifecycleRule
|
||||
from authentik.events.models import Event, Notification, NotificationTransport
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
|
||||
@actor(description=_("Dispatch tasks to validate lifecycle rules."))
|
||||
@actor(description=_("Dispatch tasks to apply lifecycle rules."))
|
||||
def apply_lifecycle_rules():
|
||||
for rule in LifecycleRule.objects.all():
|
||||
apply_lifecycle_rule.send_with_options(
|
||||
args=(rule.id,),
|
||||
rel_obj=rule,
|
||||
rel_obj=Schedule.objects.get(
|
||||
actor_name="authentik.enterprise.lifecycle.tasks.apply_lifecycle_rules"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.apps import apps
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
@@ -19,6 +20,11 @@ class TestLifecycleRuleAPI(APITestCase):
|
||||
self.content_type = ContentType.objects.get_for_model(Application)
|
||||
self.reviewer_group = Group.objects.create(name=generate_id())
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
config = apps.get_app_config("authentik_tasks_schedules")
|
||||
config._on_startup_callback(None)
|
||||
|
||||
def test_list_rules(self):
|
||||
rule = LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
@@ -190,6 +196,11 @@ class TestIterationAPI(APITestCase):
|
||||
self.reviewer_group = Group.objects.create(name=generate_id())
|
||||
self.reviewer_group.users.add(self.user)
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
config = apps.get_app_config("authentik_tasks_schedules")
|
||||
config._on_startup_callback(None)
|
||||
|
||||
def test_open_iterations(self):
|
||||
rule = LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
@@ -231,7 +242,7 @@ class TestIterationAPI(APITestCase):
|
||||
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:lifecycleiteration-latest-iteration",
|
||||
"authentik_api:lifecycleiteration-latest-iterations",
|
||||
kwargs={
|
||||
"content_type": f"{self.content_type.app_label}.{self.content_type.model}",
|
||||
"object_id": str(self.app.pk),
|
||||
@@ -239,19 +250,20 @@ class TestIterationAPI(APITestCase):
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.data["object_id"], str(self.app.pk))
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0]["object_id"], str(self.app.pk))
|
||||
|
||||
def test_latest_iteration_not_found(self):
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:lifecycleiteration-latest-iteration",
|
||||
"authentik_api:lifecycleiteration-latest-iterations",
|
||||
kwargs={
|
||||
"content_type": f"{self.content_type.app_label}.{self.content_type.model}",
|
||||
"object_id": "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(response.data, [])
|
||||
|
||||
def test_iteration_includes_user_can_review(self):
|
||||
rule = LifecycleRule.objects.create(
|
||||
@@ -279,6 +291,11 @@ class TestReviewAPI(APITestCase):
|
||||
self.reviewer_group = Group.objects.create(name=generate_id())
|
||||
self.reviewer_group.users.add(self.user)
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
config = apps.get_app_config("authentik_tasks_schedules")
|
||||
config._on_startup_callback(None)
|
||||
|
||||
def test_create_review(self):
|
||||
rule = LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime as dt
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.apps import apps
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.utils import timezone
|
||||
@@ -29,6 +30,11 @@ class TestLifecycleModels(TestCase):
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
config = apps.get_app_config("authentik_tasks_schedules")
|
||||
config._on_startup_callback(None)
|
||||
|
||||
def _get_request(self):
|
||||
return self.factory.get("/")
|
||||
|
||||
@@ -438,31 +444,6 @@ class TestLifecycleModels(TestCase):
|
||||
self.assertIn(app_one, objects)
|
||||
self.assertIn(app_two, objects)
|
||||
|
||||
def test_rule_type_excludes_objects_with_specific_rules(self):
|
||||
app_with_rule = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
app_without_rule = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(Application)
|
||||
|
||||
# Create a specific rule for app_with_rule
|
||||
LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
content_type=content_type,
|
||||
object_id=str(app_with_rule.pk),
|
||||
interval="days=30",
|
||||
)
|
||||
|
||||
# Create a type-level rule
|
||||
type_rule = LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
content_type=content_type,
|
||||
object_id=None,
|
||||
interval="days=60",
|
||||
)
|
||||
|
||||
objects = list(type_rule.get_objects())
|
||||
self.assertNotIn(app_with_rule, objects)
|
||||
self.assertIn(app_without_rule, objects)
|
||||
|
||||
def test_rule_type_apply_creates_iterations_for_all_objects(self):
|
||||
app_one = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
app_two = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
@@ -669,6 +650,73 @@ class TestLifecycleModels(TestCase):
|
||||
self.assertIn(explicit_reviewer, reviewers)
|
||||
self.assertIn(group_member, reviewers)
|
||||
|
||||
def test_multiple_rules_same_object_create_separate_iterations(self):
|
||||
"""Two rules targeting the same object each create their own iteration."""
|
||||
obj = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(obj)
|
||||
|
||||
rule_one = self._create_rule_for_object(obj, interval="days=30", grace_period="days=10")
|
||||
rule_two = self._create_rule_for_object(obj, interval="days=60", grace_period="days=20")
|
||||
|
||||
iterations = LifecycleIteration.objects.filter(
|
||||
content_type=content_type, object_id=str(obj.pk)
|
||||
)
|
||||
self.assertEqual(iterations.count(), 2)
|
||||
|
||||
iter_one = iterations.get(rule=rule_one)
|
||||
iter_two = iterations.get(rule=rule_two)
|
||||
self.assertEqual(iter_one.state, ReviewState.PENDING)
|
||||
self.assertEqual(iter_two.state, ReviewState.PENDING)
|
||||
self.assertNotEqual(iter_one.pk, iter_two.pk)
|
||||
|
||||
def test_multiple_rules_same_object_reviewed_independently(self):
|
||||
"""Reviewing one rule's iteration does not affect the other rule's iteration."""
|
||||
obj = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(obj)
|
||||
|
||||
reviewer = create_test_user()
|
||||
|
||||
rule_one = self._create_rule_for_object(obj, min_reviewers=1)
|
||||
rule_two = self._create_rule_for_object(obj, min_reviewers=1)
|
||||
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(reviewer)
|
||||
rule_one.reviewer_groups.add(group)
|
||||
rule_two.reviewer_groups.add(group)
|
||||
|
||||
iter_one = LifecycleIteration.objects.get(
|
||||
content_type=content_type, object_id=str(obj.pk), rule=rule_one
|
||||
)
|
||||
iter_two = LifecycleIteration.objects.get(
|
||||
content_type=content_type, object_id=str(obj.pk), rule=rule_two
|
||||
)
|
||||
|
||||
request = self._get_request()
|
||||
|
||||
# Review only rule_one's iteration
|
||||
Review.objects.create(iteration=iter_one, reviewer=reviewer)
|
||||
iter_one.on_review(request)
|
||||
|
||||
iter_one.refresh_from_db()
|
||||
iter_two.refresh_from_db()
|
||||
self.assertEqual(iter_one.state, ReviewState.REVIEWED)
|
||||
self.assertEqual(iter_two.state, ReviewState.PENDING)
|
||||
|
||||
def test_type_rule_and_object_rule_both_create_iterations(self):
|
||||
"""A type-level rule and an object-level rule both create iterations for the same object."""
|
||||
obj = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(obj)
|
||||
|
||||
object_rule = self._create_rule_for_object(obj, interval="days=30")
|
||||
type_rule = self._create_rule_for_type(Application, interval="days=60")
|
||||
|
||||
iterations = LifecycleIteration.objects.filter(
|
||||
content_type=content_type, object_id=str(obj.pk)
|
||||
)
|
||||
self.assertEqual(iterations.count(), 2)
|
||||
self.assertTrue(iterations.filter(rule=object_rule).exists())
|
||||
self.assertTrue(iterations.filter(rule=type_rule).exists())
|
||||
|
||||
|
||||
class TestLifecycleDateBoundaries(TestCase):
|
||||
"""Verify that start_of_day normalization ensures correct overdue/due
|
||||
@@ -679,6 +727,11 @@ class TestLifecycleDateBoundaries(TestCase):
|
||||
ensures that the boundary is always at midnight, so millisecond variations
|
||||
in task execution time do not affect results."""
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
config = apps.get_app_config("authentik_tasks_schedules")
|
||||
config._on_startup_callback(None)
|
||||
|
||||
def _create_rule_and_iteration(self, grace_period="days=1", interval="days=365"):
|
||||
app = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(Application)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Generated by Django 5.2.12 on 2026-04-04 16:58
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.contrib.postgres.fields
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
@@ -40,4 +41,109 @@ class Migration(migrations.Migration):
|
||||
]
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="stream",
|
||||
name="events_requested",
|
||||
field=django.contrib.postgres.fields.ArrayField(
|
||||
base_field=models.TextField(
|
||||
choices=[
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-revoked",
|
||||
"Caep Session Revoked",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change",
|
||||
"Caep Token Claims Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change",
|
||||
"Caep Credential Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change",
|
||||
"Caep Assurance Level Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change",
|
||||
"Caep Device Compliance Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-established",
|
||||
"Caep Session Established",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-presented",
|
||||
"Caep Session Presented",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/risk-level-change",
|
||||
"Caep Risk Level Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/ssf/event-type/verification",
|
||||
"Set Verification",
|
||||
),
|
||||
]
|
||||
),
|
||||
default=list,
|
||||
size=None,
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="stream",
|
||||
name="status",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("enabled", "Enabled"),
|
||||
("paused", "Paused"),
|
||||
("disabled", "Disabled"),
|
||||
("disabled_deleted", "Disabled Deleted"),
|
||||
],
|
||||
default="enabled",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="streamevent",
|
||||
name="type",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-revoked",
|
||||
"Caep Session Revoked",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change",
|
||||
"Caep Token Claims Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change",
|
||||
"Caep Credential Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change",
|
||||
"Caep Assurance Level Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change",
|
||||
"Caep Device Compliance Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-established",
|
||||
"Caep Session Established",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-presented",
|
||||
"Caep Session Presented",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/caep/event-type/risk-level-change",
|
||||
"Caep Risk Level Change",
|
||||
),
|
||||
(
|
||||
"https://schemas.openid.net/secevent/ssf/event-type/verification",
|
||||
"Set Verification",
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -24,8 +24,31 @@ class EventTypes(models.TextChoices):
|
||||
"""SSF Event types supported by authentik"""
|
||||
|
||||
CAEP_SESSION_REVOKED = "https://schemas.openid.net/secevent/caep/event-type/session-revoked"
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.1"""
|
||||
CAEP_TOKEN_CLAIMS_CHANGE = (
|
||||
"https://schemas.openid.net/secevent/caep/event-type/token-claims-change"
|
||||
)
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.2"""
|
||||
CAEP_CREDENTIAL_CHANGE = "https://schemas.openid.net/secevent/caep/event-type/credential-change"
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.3"""
|
||||
CAEP_ASSURANCE_LEVEL_CHANGE = (
|
||||
"https://schemas.openid.net/secevent/caep/event-type/assurance-level-change"
|
||||
)
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.4"""
|
||||
CAEP_DEVICE_COMPLIANCE_CHANGE = (
|
||||
"https://schemas.openid.net/secevent/caep/event-type/device-compliance-change"
|
||||
)
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.5"""
|
||||
CAEP_SESSION_ESTABLISHED = (
|
||||
"https://schemas.openid.net/secevent/caep/event-type/session-established"
|
||||
)
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.6"""
|
||||
CAEP_SESSION_PRESENTED = "https://schemas.openid.net/secevent/caep/event-type/session-presented"
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.7"""
|
||||
CAEP_RISK_LEVEL_CHANGE = "https://schemas.openid.net/secevent/caep/event-type/risk-level-change"
|
||||
"""https://openid.net/specs/openid-caep-1_0-final.html#section-3.8"""
|
||||
SET_VERIFICATION = "https://schemas.openid.net/secevent/ssf/event-type/verification"
|
||||
"""https://openid.net/specs/openid-sharedsignals-framework-1_0.html#section-8.1.4.1"""
|
||||
|
||||
|
||||
class DeliveryMethods(models.TextChoices):
|
||||
@@ -46,10 +69,12 @@ class SSFEventStatus(models.TextChoices):
|
||||
|
||||
|
||||
class StreamStatus(models.TextChoices):
|
||||
"""SSF Stream status"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
PAUSED = "paused"
|
||||
DISABLED = "disabled"
|
||||
DISABLED_DELETED = "disabled_deleted"
|
||||
|
||||
|
||||
class SSFProvider(TasksModel, BackchannelProvider):
|
||||
|
||||
@@ -12,7 +12,7 @@ from authentik.core.models import (
|
||||
User,
|
||||
UserTypes,
|
||||
)
|
||||
from authentik.core.signals import password_changed
|
||||
from authentik.core.signals import password_changed, password_hash_changed
|
||||
from authentik.enterprise.providers.ssf.models import (
|
||||
EventTypes,
|
||||
SSFProvider,
|
||||
@@ -84,14 +84,13 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_changed)
|
||||
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
|
||||
def _send_password_credential_change(user: User, change_type: str):
|
||||
"""Credential change trigger (password changed)"""
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
{
|
||||
"credential_type": "password",
|
||||
"change_type": "revoke" if password is None else "update",
|
||||
"change_type": change_type,
|
||||
},
|
||||
sub_id={
|
||||
"format": "complex",
|
||||
@@ -103,6 +102,16 @@ def ssf_password_changed_cred_change(sender, user: User, password: str | None, *
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_hash_changed)
|
||||
@receiver(password_changed)
|
||||
def ssf_password_changed_cred_change(signal, sender, user: User, password: str | None = None, **_):
|
||||
"""Credential change trigger (password changed)"""
|
||||
if signal is password_hash_changed:
|
||||
_send_password_credential_change(user, "update")
|
||||
return
|
||||
_send_password_credential_change(user, "revoke" if password is None else "update")
|
||||
|
||||
|
||||
device_type_map = {
|
||||
StaticDevice: "pin",
|
||||
TOTPDevice: "pin",
|
||||
|
||||
@@ -108,13 +108,13 @@ def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
|
||||
event.save()
|
||||
self.info("Event successfully sent", status=response.status_code)
|
||||
# Cleanup, if we were the last pending message for this stream and it has been deleted
|
||||
# (status=StreamStatus.DISABLED), then we can delete the stream
|
||||
# (status=StreamStatus.DISABLED_DELETED), then we can delete the stream
|
||||
if (
|
||||
not StreamEvent.objects.filter(
|
||||
stream=stream,
|
||||
status__in=[SSFEventStatus.PENDING_FAILED, SSFEventStatus.PENDING_NEW],
|
||||
).exists()
|
||||
and stream.status == StreamStatus.DISABLED
|
||||
and stream.status == StreamStatus.DISABLED_DELETED
|
||||
):
|
||||
LOGGER.info(
|
||||
"Deleting inactive stream as all pending messages were sent.", stream=stream
|
||||
|
||||
@@ -62,7 +62,7 @@ class TestSSFAuth(APITestCase):
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
self.assertEqual(
|
||||
event.payload["events"],
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
|
||||
)
|
||||
|
||||
def test_stream_add_oidc(self):
|
||||
@@ -115,7 +115,7 @@ class TestSSFAuth(APITestCase):
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
self.assertEqual(
|
||||
event.payload["events"],
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
|
||||
)
|
||||
|
||||
def test_token_invalid(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@@ -52,6 +53,21 @@ class TestSignals(APITestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 201, res.content)
|
||||
|
||||
def _assert_password_credential_change(self, user, change_type: str):
|
||||
stream = Stream.objects.filter(provider=self.provider).first()
|
||||
self.assertIsNotNone(stream)
|
||||
event = StreamEvent.objects.filter(stream=stream).first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
event_payload = event.payload["events"][
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
|
||||
]
|
||||
self.assertEqual(event_payload["change_type"], change_type)
|
||||
self.assertEqual(event_payload["credential_type"], "password")
|
||||
self.assertEqual(event.payload["sub_id"]["format"], "complex")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
|
||||
|
||||
def test_signal_logout(self):
|
||||
"""Test user logout"""
|
||||
user = create_test_user()
|
||||
@@ -79,19 +95,25 @@ class TestSignals(APITestCase):
|
||||
user.set_password(generate_id())
|
||||
user.save()
|
||||
|
||||
stream = Stream.objects.filter(provider=self.provider).first()
|
||||
self.assertIsNotNone(stream)
|
||||
event = StreamEvent.objects.filter(stream=stream).first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
event_payload = event.payload["events"][
|
||||
"https://schemas.openid.net/secevent/caep/event-type/credential-change"
|
||||
]
|
||||
self.assertEqual(event_payload["change_type"], "update")
|
||||
self.assertEqual(event_payload["credential_type"], "password")
|
||||
self.assertEqual(event.payload["sub_id"]["format"], "complex")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["format"], "email")
|
||||
self.assertEqual(event.payload["sub_id"]["user"]["email"], user.email)
|
||||
self._assert_password_credential_change(user, "update")
|
||||
|
||||
def test_signal_password_change_from_hash(self):
|
||||
"""Test user password change from a pre-hashed password."""
|
||||
user = create_test_user()
|
||||
self.client.force_login(user)
|
||||
user.set_password_from_hash(make_password(generate_id()))
|
||||
user.save()
|
||||
|
||||
self._assert_password_credential_change(user, "update")
|
||||
|
||||
def test_signal_password_revoke(self):
|
||||
"""Test explicit password revoke."""
|
||||
user = create_test_user()
|
||||
self.client.force_login(user)
|
||||
user.set_password(None)
|
||||
user.save()
|
||||
|
||||
self._assert_password_credential_change(user, "revoke")
|
||||
|
||||
def test_signal_authenticator_added(self):
|
||||
"""Test authenticator creation signal"""
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestStream(APITestCase):
|
||||
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
|
||||
self.assertEqual(
|
||||
event.payload["events"],
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
|
||||
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {}},
|
||||
)
|
||||
|
||||
def test_stream_add_poll(self):
|
||||
@@ -96,7 +96,7 @@ class TestStream(APITestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 204)
|
||||
stream.refresh_from_db()
|
||||
self.assertEqual(stream.status, StreamStatus.DISABLED)
|
||||
self.assertEqual(stream.status, StreamStatus.DISABLED_DELETED)
|
||||
|
||||
def test_stream_get(self):
|
||||
"""get stream"""
|
||||
@@ -225,3 +225,26 @@ class TestStream(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.provider.token.key}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 404)
|
||||
|
||||
def test_stream_status_update(self):
|
||||
stream = Stream.objects.create(provider=self.provider)
|
||||
res = self.client.post(
|
||||
reverse(
|
||||
"authentik_providers_ssf:stream-status",
|
||||
kwargs={"application_slug": self.application.slug},
|
||||
),
|
||||
data={
|
||||
"stream_id": str(stream.pk),
|
||||
"status": StreamStatus.DISABLED,
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.provider.token.key}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
stream.refresh_from_db()
|
||||
self.assertJSONEqual(
|
||||
res.content,
|
||||
{
|
||||
"stream_id": str(stream.pk),
|
||||
"status": str(stream.status),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ class TestTasks(APITestCase):
|
||||
)
|
||||
event_data = stream.prepare_event_payload(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{"state": None},
|
||||
{},
|
||||
sub_id={"format": "opaque", "id": str(stream.uuid)},
|
||||
)
|
||||
with Mocker() as mocker:
|
||||
@@ -46,7 +46,7 @@ class TestTasks(APITestCase):
|
||||
)
|
||||
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
|
||||
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
|
||||
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
|
||||
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
|
||||
|
||||
def test_push_auth(self):
|
||||
auth = generate_id()
|
||||
@@ -58,7 +58,7 @@ class TestTasks(APITestCase):
|
||||
)
|
||||
event_data = stream.prepare_event_payload(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{"state": None},
|
||||
{},
|
||||
sub_id={"format": "opaque", "id": str(stream.uuid)},
|
||||
)
|
||||
with Mocker() as mocker:
|
||||
@@ -72,7 +72,7 @@ class TestTasks(APITestCase):
|
||||
)
|
||||
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
|
||||
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
|
||||
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
|
||||
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
|
||||
|
||||
def test_push_stream_disable(self):
|
||||
auth = generate_id()
|
||||
@@ -81,11 +81,11 @@ class TestTasks(APITestCase):
|
||||
delivery_method=DeliveryMethods.RFC_PUSH,
|
||||
endpoint_url="http://localhost/ssf-push",
|
||||
authorization_header=auth,
|
||||
status=StreamStatus.DISABLED,
|
||||
status=StreamStatus.DISABLED_DELETED,
|
||||
)
|
||||
event_data = stream.prepare_event_payload(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{"state": None},
|
||||
{},
|
||||
sub_id={"format": "opaque", "id": str(stream.uuid)},
|
||||
)
|
||||
with Mocker() as mocker:
|
||||
@@ -95,7 +95,7 @@ class TestTasks(APITestCase):
|
||||
).get_result(block=True, timeout=1)
|
||||
jwt = decode_complete(mocker.request_history[0].body, options={"verify_signature": False})
|
||||
self.assertEqual(jwt["header"]["typ"], "secevent+jwt")
|
||||
self.assertIsNone(jwt["payload"]["events"][EventTypes.SET_VERIFICATION]["state"])
|
||||
self.assertEqual(jwt["payload"]["events"][EventTypes.SET_VERIFICATION], {})
|
||||
self.assertFalse(Stream.objects.filter(pk=stream.pk).exists())
|
||||
|
||||
def test_push_error(self):
|
||||
@@ -106,7 +106,7 @@ class TestTasks(APITestCase):
|
||||
)
|
||||
event_data = stream.prepare_event_payload(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{"state": None},
|
||||
{},
|
||||
sub_id={"format": "opaque", "id": str(stream.uuid)},
|
||||
)
|
||||
with Mocker() as mocker:
|
||||
|
||||
@@ -24,10 +24,10 @@ class SSFView(APIView):
|
||||
|
||||
|
||||
class SSFStreamView(SSFView):
|
||||
def get_object(self, any_status=False) -> Stream:
|
||||
streams = Stream.objects.filter(provider=self.provider)
|
||||
if not any_status:
|
||||
streams = streams.filter(status__in=[StreamStatus.ENABLED, StreamStatus.PAUSED])
|
||||
def get_object(self) -> Stream:
|
||||
streams = Stream.objects.filter(provider=self.provider).exclude(
|
||||
status=StreamStatus.DISABLED_DELETED
|
||||
)
|
||||
if "stream_id" in self.request.query_params:
|
||||
streams = streams.filter(pk=self.request.query_params["stream_id"])
|
||||
if "stream_id" in self.request.data:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.http import Http404, HttpRequest
|
||||
from django.urls import reverse
|
||||
from rest_framework.exceptions import PermissionDenied, ValidationError
|
||||
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
|
||||
@@ -106,7 +106,11 @@ class StreamResponseSerializer(PassiveSerializer):
|
||||
}
|
||||
|
||||
def get_events_supported(self, instance: Stream) -> list[str]:
|
||||
return [x.value for x in EventTypes]
|
||||
return [
|
||||
EventTypes.CAEP_SESSION_REVOKED,
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
EventTypes.SET_VERIFICATION,
|
||||
]
|
||||
|
||||
|
||||
class StreamView(SSFStreamView):
|
||||
@@ -128,10 +132,9 @@ class StreamView(SSFStreamView):
|
||||
LOGGER.info("Sending verification event", stream=instance)
|
||||
send_ssf_events(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{
|
||||
"state": None,
|
||||
},
|
||||
{},
|
||||
stream_filter={"pk": instance.uuid},
|
||||
request=request,
|
||||
sub_id={"format": "opaque", "id": str(instance.uuid)},
|
||||
)
|
||||
response = StreamResponseSerializer(instance=instance, context={"request": request}).data
|
||||
@@ -159,7 +162,9 @@ class StreamView(SSFStreamView):
|
||||
|
||||
def delete(self, request: Request, *args, **kwargs) -> Response:
|
||||
stream = self.get_object()
|
||||
stream.status = StreamStatus.DISABLED
|
||||
if stream.status == StreamStatus.DISABLED_DELETED:
|
||||
raise Http404
|
||||
stream.status = StreamStatus.DISABLED_DELETED
|
||||
stream.save()
|
||||
return Response(status=204)
|
||||
|
||||
@@ -175,6 +180,7 @@ class StreamVerifyView(SSFStreamView):
|
||||
"state": state,
|
||||
},
|
||||
stream_filter={"pk": stream.uuid},
|
||||
request=request,
|
||||
sub_id={"format": "opaque", "id": str(stream.uuid)},
|
||||
)
|
||||
return Response(status=204)
|
||||
@@ -182,8 +188,25 @@ class StreamVerifyView(SSFStreamView):
|
||||
|
||||
class StreamStatusView(SSFStreamView):
|
||||
|
||||
class StreamStatusSerializer(PassiveSerializer):
|
||||
stream_id = CharField()
|
||||
status = ChoiceField(choices=StreamStatus.choices)
|
||||
|
||||
def get(self, request: Request, *args, **kwargs):
|
||||
stream = self.get_object(any_status=True)
|
||||
stream = self.get_object()
|
||||
return Response(
|
||||
{
|
||||
"stream_id": str(stream.pk),
|
||||
"status": str(stream.status),
|
||||
}
|
||||
)
|
||||
|
||||
def post(self, request: Request, *args, **kwargs):
|
||||
stream = self.get_object()
|
||||
serializer = self.StreamStatusSerializer(stream, data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
stream.status = serializer.validated_data["status"]
|
||||
stream.save()
|
||||
return Response(
|
||||
{
|
||||
"stream_id": str(stream.pk),
|
||||
|
||||
@@ -14,6 +14,7 @@ TENANT_APPS = [
|
||||
"authentik.enterprise.providers.ssf",
|
||||
"authentik.enterprise.providers.ws_federation",
|
||||
"authentik.enterprise.reports",
|
||||
"authentik.enterprise.stages.account_lockdown",
|
||||
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
|
||||
"authentik.enterprise.stages.mtls",
|
||||
"authentik.enterprise.stages.source",
|
||||
|
||||
141
authentik/enterprise/stages/account_lockdown/api.py
Normal file
141
authentik/enterprise/stages/account_lockdown/api.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Account Lockdown Stage API Views"""
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.utils import OpenApiExample, OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import PrimaryKeyRelatedField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.validation import validate
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import LinkSerializer, PassiveSerializer
|
||||
from authentik.core.models import (
|
||||
User,
|
||||
)
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin, enterprise_action
|
||||
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
|
||||
from authentik.enterprise.stages.account_lockdown.stage import (
|
||||
can_lock_user,
|
||||
get_lockdown_target_users,
|
||||
)
|
||||
from authentik.flows.api.stages import StageSerializer
|
||||
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class AccountLockdownStageSerializer(EnterpriseRequiredMixin, StageSerializer):
|
||||
"""AccountLockdownStage Serializer"""
|
||||
|
||||
class Meta:
|
||||
model = AccountLockdownStage
|
||||
fields = StageSerializer.Meta.fields + [
|
||||
"deactivate_user",
|
||||
"set_unusable_password",
|
||||
"delete_sessions",
|
||||
"revoke_tokens",
|
||||
"self_service_completion_flow",
|
||||
]
|
||||
|
||||
|
||||
class AccountLockdownStageViewSet(UsedByMixin, ModelViewSet):
|
||||
"""AccountLockdownStage Viewset"""
|
||||
|
||||
queryset = AccountLockdownStage.objects.all()
|
||||
serializer_class = AccountLockdownStageSerializer
|
||||
filterset_fields = "__all__"
|
||||
ordering = ["name"]
|
||||
search_fields = ["name"]
|
||||
|
||||
|
||||
class UserAccountLockdownSerializer(PassiveSerializer):
|
||||
"""Choose the target account before starting the lockdown flow."""
|
||||
|
||||
user = PrimaryKeyRelatedField(
|
||||
queryset=get_lockdown_target_users(),
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text=_("User to lock. If omitted, locks the current user (self-service)."),
|
||||
)
|
||||
|
||||
|
||||
class UserAccountLockdownMixin:
|
||||
"""Enterprise account-lockdown API actions for UserViewSet."""
|
||||
|
||||
def _create_lockdown_flow_url(self, request: Request, user: User) -> str:
|
||||
"""Create a flow URL for account lockdown.
|
||||
|
||||
The request body selects the target before the flow starts. The API
|
||||
pre-plans the lockdown flow with the target as the pending user, so the
|
||||
account lockdown stage can use the normal flow context.
|
||||
"""
|
||||
flow = request._request.brand.flow_lockdown
|
||||
if flow is None:
|
||||
raise ValidationError({"non_field_errors": [_("No lockdown flow configured.")]})
|
||||
planner = FlowPlanner(flow)
|
||||
planner.use_cache = False
|
||||
try:
|
||||
plan = planner.plan(request._request, {PLAN_CONTEXT_PENDING_USER: user})
|
||||
except EmptyFlowException, FlowNonApplicableException:
|
||||
raise ValidationError(
|
||||
{"non_field_errors": [_("Lockdown flow is not applicable.")]}
|
||||
) from None
|
||||
return plan.to_redirect(request._request, flow).url
|
||||
|
||||
@extend_schema(
|
||||
description=_("Choose the target account, then return a flow link."),
|
||||
request=UserAccountLockdownSerializer,
|
||||
responses={
|
||||
"200": OpenApiResponse(
|
||||
response=LinkSerializer,
|
||||
examples=[
|
||||
OpenApiExample(
|
||||
"Lockdown flow URL",
|
||||
value={
|
||||
"link": "https://example.invalid/if/flow/default-account-lockdown/",
|
||||
},
|
||||
response_only=True,
|
||||
status_codes=["200"],
|
||||
)
|
||||
],
|
||||
),
|
||||
"400": OpenApiResponse(
|
||||
description=_("No lockdown flow configured or the flow is not applicable")
|
||||
),
|
||||
"403": OpenApiResponse(
|
||||
description=_("Permission denied (when targeting another user)")
|
||||
),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
detail=False,
|
||||
methods=["POST"],
|
||||
permission_classes=[IsAuthenticated],
|
||||
url_path="account_lockdown",
|
||||
)
|
||||
@validate(UserAccountLockdownSerializer)
|
||||
@enterprise_action
|
||||
def account_lockdown(self, request: Request, body: UserAccountLockdownSerializer) -> Response:
|
||||
"""Trigger account lockdown for a user.
|
||||
|
||||
If no user is specified, locks the current user (self-service).
|
||||
When targeting another user, admin permissions are required.
|
||||
|
||||
Returns a flow link for the frontend to follow. The flow is pre-planned
|
||||
with the target user as pending user for the lockdown stage.
|
||||
"""
|
||||
user = body.validated_data.get("user") or request.user
|
||||
|
||||
if not can_lock_user(request.user, user):
|
||||
LOGGER.debug("Permission denied for account lockdown", user=request.user)
|
||||
self.permission_denied(request)
|
||||
|
||||
flow_url = self._create_lockdown_flow_url(request, user)
|
||||
LOGGER.debug("Returning lockdown flow URL", flow_url=flow_url, user=user.username)
|
||||
return Response({"link": flow_url})
|
||||
12
authentik/enterprise/stages/account_lockdown/apps.py
Normal file
12
authentik/enterprise/stages/account_lockdown/apps.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""authentik account lockdown stage app config"""
|
||||
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
|
||||
|
||||
class AuthentikEnterpriseStageAccountLockdownConfig(EnterpriseConfig):
|
||||
"""authentik account lockdown stage config"""
|
||||
|
||||
name = "authentik.enterprise.stages.account_lockdown"
|
||||
label = "authentik_stages_account_lockdown"
|
||||
verbose_name = "authentik Enterprise.Stages.Account Lockdown"
|
||||
default = True
|
||||
@@ -0,0 +1,74 @@
|
||||
# Generated by Django 5.2.13 on 2026-04-19 21:56
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_flows", "0031_alter_flow_layout"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="AccountLockdownStage",
|
||||
fields=[
|
||||
(
|
||||
"stage_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_flows.stage",
|
||||
),
|
||||
),
|
||||
(
|
||||
"deactivate_user",
|
||||
models.BooleanField(
|
||||
default=True,
|
||||
help_text="Deactivate the user account (set is_active to False)",
|
||||
),
|
||||
),
|
||||
(
|
||||
"set_unusable_password",
|
||||
models.BooleanField(
|
||||
default=True, help_text="Set an unusable password for the user"
|
||||
),
|
||||
),
|
||||
(
|
||||
"delete_sessions",
|
||||
models.BooleanField(
|
||||
default=True, help_text="Delete all active sessions for the user"
|
||||
),
|
||||
),
|
||||
(
|
||||
"revoke_tokens",
|
||||
models.BooleanField(
|
||||
default=True,
|
||||
help_text="Revoke all tokens for the user (API, app password, recovery, verification, OAuth)",
|
||||
),
|
||||
),
|
||||
(
|
||||
"self_service_completion_flow",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
help_text="Flow to redirect users to after self-service lockdown. This flow should not require authentication since the user's session is deleted.",
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="account_lockdown_stages",
|
||||
to="authentik_flows.flow",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Account Lockdown Stage",
|
||||
"verbose_name_plural": "Account Lockdown Stages",
|
||||
},
|
||||
bases=("authentik_flows.stage",),
|
||||
),
|
||||
]
|
||||
62
authentik/enterprise/stages/account_lockdown/models.py
Normal file
62
authentik/enterprise/stages/account_lockdown/models.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Account lockdown stage models"""
|
||||
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views import View
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
|
||||
from authentik.flows.models import Stage
|
||||
|
||||
|
||||
class AccountLockdownStage(Stage):
|
||||
"""Lock down a target user account."""
|
||||
|
||||
deactivate_user = models.BooleanField(
|
||||
default=True,
|
||||
help_text=_("Deactivate the user account (set is_active to False)"),
|
||||
)
|
||||
set_unusable_password = models.BooleanField(
|
||||
default=True,
|
||||
help_text=_("Set an unusable password for the user"),
|
||||
)
|
||||
delete_sessions = models.BooleanField(
|
||||
default=True,
|
||||
help_text=_("Delete all active sessions for the user"),
|
||||
)
|
||||
revoke_tokens = models.BooleanField(
|
||||
default=True,
|
||||
help_text=_(
|
||||
"Revoke all tokens for the user (API, app password, recovery, verification, OAuth)"
|
||||
),
|
||||
)
|
||||
self_service_completion_flow = models.ForeignKey(
|
||||
"authentik_flows.Flow",
|
||||
on_delete=models.SET_NULL,
|
||||
null=True,
|
||||
blank=True,
|
||||
related_name="account_lockdown_stages",
|
||||
help_text=_(
|
||||
"Flow to redirect users to after self-service lockdown. "
|
||||
"This flow should not require authentication since the user's session is deleted."
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
from authentik.enterprise.stages.account_lockdown.api import AccountLockdownStageSerializer
|
||||
|
||||
return AccountLockdownStageSerializer
|
||||
|
||||
@property
|
||||
def view(self) -> type[View]:
|
||||
from authentik.enterprise.stages.account_lockdown.stage import AccountLockdownStageView
|
||||
|
||||
return AccountLockdownStageView
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-stage-account-lockdown-form"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Account Lockdown Stage")
|
||||
verbose_name_plural = _("Account Lockdown Stages")
|
||||
345
authentik/enterprise/stages/account_lockdown/stage.py
Normal file
345
authentik/enterprise/stages/account_lockdown/stage.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Account lockdown stage logic"""
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.transaction import atomic
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.composition import group
|
||||
from dramatiq.results.errors import ResultTimeout
|
||||
|
||||
from authentik.core.models import (
|
||||
AuthenticatedSession,
|
||||
ExpiringModel,
|
||||
Session,
|
||||
Token,
|
||||
User,
|
||||
UserTypes,
|
||||
)
|
||||
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.sync.outgoing.signals import sync_outgoing_inhibit_dispatch
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
PLAN_CONTEXT_LOCKDOWN_REASON = "lockdown_reason"
|
||||
LOCKDOWN_EVENT_ACTION_ID = "account_lockdown"
|
||||
|
||||
TARGET_REQUIRED_MESSAGE = _("No target user specified for account lockdown")
|
||||
PERMISSION_DENIED_MESSAGE = _("You do not have permission to lock down this account.")
|
||||
ACCOUNT_LOCKDOWN_FAILED_MESSAGE = _("Account lockdown failed for this account.")
|
||||
SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE = _(
|
||||
"Self-service account lockdown requires a completion flow."
|
||||
)
|
||||
|
||||
|
||||
def get_lockdown_target_users() -> QuerySet[User]:
|
||||
"""Return users that can be targeted by account lockdown."""
|
||||
return User.objects.exclude_anonymous().exclude(type=UserTypes.INTERNAL_SERVICE_ACCOUNT)
|
||||
|
||||
|
||||
def _get_model_field(model: type[Model], field_name: str):
|
||||
"""Get a model field by name, if present."""
|
||||
try:
|
||||
return model._meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def _has_user_field(model: type[Model]) -> bool:
|
||||
"""Check if a model has a direct user foreign key."""
|
||||
field = _get_model_field(model, "user")
|
||||
return bool(field and getattr(field, "remote_field", None) and field.remote_field.model is User)
|
||||
|
||||
|
||||
def _has_authenticated_session_field(model: type[Model]) -> bool:
|
||||
"""Check if a model is linked to an authenticated session."""
|
||||
field = _get_model_field(model, "session")
|
||||
return bool(
|
||||
field
|
||||
and getattr(field, "remote_field", None)
|
||||
and field.remote_field.model is AuthenticatedSession
|
||||
)
|
||||
|
||||
|
||||
def _has_provider_field(model: type[Model]) -> bool:
|
||||
"""Check if a model is linked to a provider."""
|
||||
return _get_model_field(model, "provider") is not None
|
||||
|
||||
|
||||
def get_lockdown_token_models() -> tuple[type[Model], ...]:
|
||||
"""Return token, grant, and provider session models removed by account lockdown."""
|
||||
token_models: list[type[Model]] = []
|
||||
for model in apps.get_models():
|
||||
if model._meta.abstract or not issubclass(model, ExpiringModel):
|
||||
continue
|
||||
if model is Token:
|
||||
token_models.append(model)
|
||||
elif _has_user_field(model) and (
|
||||
_has_provider_field(model) or _has_authenticated_session_field(model)
|
||||
):
|
||||
token_models.append(model)
|
||||
elif _has_authenticated_session_field(model):
|
||||
token_models.append(model)
|
||||
return tuple(token_models)
|
||||
|
||||
|
||||
def get_lockdown_token_queryset(model: type[Model], user: User) -> QuerySet:
|
||||
"""Return account lockdown artifacts for a model and user."""
|
||||
manager = model.objects.including_expired()
|
||||
if _has_user_field(model):
|
||||
return manager.filter(user=user)
|
||||
return manager.filter(session__user=user)
|
||||
|
||||
|
||||
def can_lock_user(actor, user: User) -> bool:
|
||||
"""Check whether the actor may lock the target user."""
|
||||
if not actor.is_authenticated:
|
||||
return False
|
||||
if user.pk == actor.pk:
|
||||
return True
|
||||
return actor.has_perm("authentik_core.change_user", user)
|
||||
|
||||
|
||||
def get_outgoing_sync_tasks() -> tuple[tuple[type[OutgoingSyncProvider], Actor], ...]:
|
||||
"""Return outgoing sync provider types and their direct sync tasks."""
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync_direct
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync_direct
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync_direct
|
||||
|
||||
return (
|
||||
(SCIMProvider, scim_sync_direct),
|
||||
(GoogleWorkspaceProvider, google_workspace_sync_direct),
|
||||
(MicrosoftEntraProvider, microsoft_entra_sync_direct),
|
||||
)
|
||||
|
||||
|
||||
class AccountLockdownStageView(StageView):
|
||||
"""Execute account lockdown actions on the target user."""
|
||||
|
||||
def is_self_service(self, request: HttpRequest, user: User) -> bool:
|
||||
"""Check whether the currently authenticated user is locking their own account."""
|
||||
return request.user.is_authenticated and user.pk == request.user.pk
|
||||
|
||||
def get_reason(self) -> str:
|
||||
"""Get the lockdown reason from the plan context.
|
||||
|
||||
Priority:
|
||||
1. prompt_data[PLAN_CONTEXT_LOCKDOWN_REASON]
|
||||
2. PLAN_CONTEXT_LOCKDOWN_REASON (explicitly set)
|
||||
3. Empty string as fallback
|
||||
"""
|
||||
prompt_data = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
|
||||
if PLAN_CONTEXT_LOCKDOWN_REASON in prompt_data:
|
||||
return prompt_data[PLAN_CONTEXT_LOCKDOWN_REASON]
|
||||
return self.executor.plan.context.get(PLAN_CONTEXT_LOCKDOWN_REASON, "")
|
||||
|
||||
def _apply_lockdown_actions(self, stage: AccountLockdownStage, user: User) -> None:
|
||||
"""Apply the configured account changes to the target user."""
|
||||
if stage.deactivate_user:
|
||||
user.is_active = False
|
||||
if stage.set_unusable_password:
|
||||
user.set_unusable_password()
|
||||
if stage.deactivate_user:
|
||||
with sync_outgoing_inhibit_dispatch():
|
||||
user.save()
|
||||
return
|
||||
user.save()
|
||||
|
||||
def _sync_deactivated_user_to_outgoing_providers(self, user: User) -> None:
|
||||
"""Synchronize a deactivated user to outgoing sync providers."""
|
||||
messages = []
|
||||
wait_timeout = 0
|
||||
model = class_to_path(User)
|
||||
provider_filter = Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
|
||||
for provider_model, task_sync_direct in get_outgoing_sync_tasks():
|
||||
for provider in provider_model.objects.filter(provider_filter):
|
||||
time_limit = int(
|
||||
timedelta_from_string(provider.sync_page_timeout).total_seconds() * 1000
|
||||
)
|
||||
messages.append(
|
||||
task_sync_direct.message_with_options(
|
||||
args=(model, user.pk, provider.pk),
|
||||
rel_obj=provider,
|
||||
time_limit=time_limit,
|
||||
uid=f"{provider.name}:user:{user.pk}:direct",
|
||||
)
|
||||
)
|
||||
wait_timeout += time_limit
|
||||
|
||||
if not messages:
|
||||
return
|
||||
try:
|
||||
group(messages).run().wait(timeout=wait_timeout)
|
||||
except ResultTimeout:
|
||||
self.logger.warning(
|
||||
"Timed out waiting for outgoing sync tasks; tasks remain queued",
|
||||
user=user.username,
|
||||
timeout=wait_timeout,
|
||||
)
|
||||
|
||||
def _get_lockdown_artifact_querysets(
|
||||
self, stage: AccountLockdownStage, user: User
|
||||
) -> tuple[QuerySet, ...]:
|
||||
"""Return the configured sessions and tokens targeted by lockdown."""
|
||||
querysets: list[QuerySet] = []
|
||||
if stage.delete_sessions:
|
||||
querysets.append(Session.objects.filter(authenticatedsession__user=user))
|
||||
if stage.revoke_tokens:
|
||||
querysets.extend(
|
||||
get_lockdown_token_queryset(model, user) for model in get_lockdown_token_models()
|
||||
)
|
||||
return tuple(querysets)
|
||||
|
||||
def _delete_lockdown_artifacts(self, stage: AccountLockdownStage, user: User) -> None:
|
||||
"""Delete sessions and tokens selected by the lockdown configuration."""
|
||||
for queryset in self._get_lockdown_artifact_querysets(stage, user):
|
||||
queryset.delete()
|
||||
|
||||
def _has_lockdown_artifacts(self, stage: AccountLockdownStage, user: User) -> bool:
|
||||
"""Check whether there are still sessions or tokens to remove."""
|
||||
return any(
|
||||
queryset.exists() for queryset in self._get_lockdown_artifact_querysets(stage, user)
|
||||
)
|
||||
|
||||
def _emit_lockdown_event(self, request: HttpRequest, user: User, reason: str) -> None:
|
||||
"""Emit the audit event for a completed lockdown."""
|
||||
# Emit the audit event after the transaction commits. If event creation
|
||||
# fails here, dispatch() would otherwise treat the whole lockdown as
|
||||
# failed even though the account changes have already been committed.
|
||||
try:
|
||||
Event.new(
|
||||
EventAction.USER_WRITE,
|
||||
action_id=LOCKDOWN_EVENT_ACTION_ID,
|
||||
reason=reason,
|
||||
affected_user=user.username,
|
||||
).from_http(request)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Event emission should not make the lockdown itself fail.
|
||||
self.logger.warning(
|
||||
"Failed to emit account lockdown event",
|
||||
user=user.username,
|
||||
exc=exc,
|
||||
)
|
||||
|
||||
def _lockdown_user(
|
||||
self,
|
||||
request: HttpRequest,
|
||||
stage: AccountLockdownStage,
|
||||
user: User,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""Execute lockdown actions on a single user."""
|
||||
with atomic():
|
||||
user = User.objects.get(pk=user.pk)
|
||||
self._apply_lockdown_actions(stage, user)
|
||||
self._delete_lockdown_artifacts(stage, user)
|
||||
|
||||
# These additional checks/deletes are done to prevent a timing attack that creates tokens
|
||||
# with a compromised token that is simultaneously being deleted.
|
||||
while self._has_lockdown_artifacts(stage, user):
|
||||
with atomic():
|
||||
self._delete_lockdown_artifacts(stage, user)
|
||||
|
||||
if stage.deactivate_user:
|
||||
try:
|
||||
self._sync_deactivated_user_to_outgoing_providers(user)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Local lockdown has already committed. Provider sync failures
|
||||
# must not reopen access or mark the lockdown itself as failed.
|
||||
self.logger.warning(
|
||||
"Failed to sync account lockdown deactivation to outgoing providers",
|
||||
user=user.username,
|
||||
exc=exc,
|
||||
)
|
||||
self._emit_lockdown_event(request, user, reason)
|
||||
|
||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Execute account lockdown actions."""
|
||||
self.request = request
|
||||
stage: AccountLockdownStage = self.executor.current_stage
|
||||
|
||||
pending_user = self.get_pending_user()
|
||||
if not pending_user.is_authenticated:
|
||||
self.logger.warning("No target user found for account lockdown")
|
||||
return self.executor.stage_invalid(TARGET_REQUIRED_MESSAGE)
|
||||
user = get_lockdown_target_users().filter(pk=pending_user.pk).first()
|
||||
if user is None:
|
||||
self.logger.warning("Target user is not eligible for account lockdown")
|
||||
return self.executor.stage_invalid(TARGET_REQUIRED_MESSAGE)
|
||||
if not can_lock_user(request.user, user):
|
||||
self.logger.warning(
|
||||
"Permission denied for account lockdown",
|
||||
actor=getattr(request.user, "username", None),
|
||||
target=user.username,
|
||||
)
|
||||
return self.executor.stage_invalid(PERMISSION_DENIED_MESSAGE)
|
||||
|
||||
reason = self.get_reason()
|
||||
self_service = self.is_self_service(request, user)
|
||||
if self_service and stage.delete_sessions and not stage.self_service_completion_flow:
|
||||
self.logger.warning("No completion flow configured for self-service account lockdown")
|
||||
return self.executor.stage_invalid(SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE)
|
||||
|
||||
self.logger.info(
|
||||
"Executing account lockdown",
|
||||
user=user.username,
|
||||
reason=reason,
|
||||
self_service=self_service,
|
||||
deactivate_user=stage.deactivate_user,
|
||||
set_unusable_password=stage.set_unusable_password,
|
||||
delete_sessions=stage.delete_sessions,
|
||||
revoke_tokens=stage.revoke_tokens,
|
||||
)
|
||||
|
||||
try:
|
||||
self._lockdown_user(request, stage, user, reason)
|
||||
self.logger.info("Account lockdown completed", user=user.username)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Convert unexpected lockdown errors to a flow-stage failure instead
|
||||
# of leaking an exception through the flow executor.
|
||||
self.logger.warning("Account lockdown failed", user=user.username, exc=exc)
|
||||
return self.executor.stage_invalid(ACCOUNT_LOCKDOWN_FAILED_MESSAGE)
|
||||
|
||||
if self_service:
|
||||
if stage.delete_sessions:
|
||||
return self._self_service_completion_response(request)
|
||||
return self.executor.stage_ok()
|
||||
|
||||
return self.executor.stage_ok()
|
||||
|
||||
def _self_service_completion_response(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Redirect to completion flow after self-service lockdown.
|
||||
|
||||
Since all sessions are deleted, the user cannot continue in the flow.
|
||||
Redirect them to an unauthenticated completion flow that shows the
|
||||
lockdown message.
|
||||
|
||||
We use a direct HTTP redirect instead of a challenge because the
|
||||
flow executor's challenge handling may try to access the session
|
||||
which we just deleted.
|
||||
"""
|
||||
stage: AccountLockdownStage = self.executor.current_stage
|
||||
completion_flow = stage.self_service_completion_flow
|
||||
if completion_flow:
|
||||
# Flush the current request's session to prevent Django's session
|
||||
# middleware from trying to save a deleted session
|
||||
if hasattr(request, "session"):
|
||||
request.session.flush()
|
||||
redirect_to = reverse(
|
||||
"authentik_core:if-flow",
|
||||
kwargs={"flow_slug": completion_flow.slug},
|
||||
)
|
||||
return HttpResponseRedirect(redirect_to)
|
||||
return self.executor.stage_invalid(SELF_SERVICE_COMPLETION_FLOW_REQUIRED_MESSAGE)
|
||||
148
authentik/enterprise/stages/account_lockdown/tests/test_api.py
Normal file
148
authentik/enterprise/stages/account_lockdown/tests/test_api.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Test Users Account Lockdown API"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import (
|
||||
create_test_brand,
|
||||
create_test_flow,
|
||||
create_test_user,
|
||||
)
|
||||
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
|
||||
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
# Patch for enterprise license check
|
||||
patch_license = patch(
|
||||
"authentik.enterprise.models.LicenseUsageStatus.is_valid",
|
||||
MagicMock(return_value=True),
|
||||
)
|
||||
|
||||
|
||||
@patch_license
|
||||
class AccountLockdownAPITestCase(APITestCase):
|
||||
"""Shared helpers for account lockdown API tests."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.lockdown_flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
|
||||
self.lockdown_stage = AccountLockdownStage.objects.create(name=generate_id())
|
||||
FlowStageBinding.objects.create(
|
||||
target=self.lockdown_flow,
|
||||
stage=self.lockdown_stage,
|
||||
order=0,
|
||||
)
|
||||
self.brand = create_test_brand()
|
||||
self.brand.flow_lockdown = self.lockdown_flow
|
||||
self.brand.save()
|
||||
|
||||
def create_user_with_email(self):
|
||||
"""Create a regular user with a unique email address."""
|
||||
user = create_test_user()
|
||||
user.email = f"{generate_id()}@test.com"
|
||||
user.save()
|
||||
return user
|
||||
|
||||
def assert_redirect_targets(self, response, user):
|
||||
"""Assert that a response contains a pre-planned lockdown flow link for a user."""
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertIn(self.lockdown_flow.slug, body["link"])
|
||||
self.assertEqual(urlparse(body["link"]).query, "")
|
||||
plan = self.client.session[SESSION_KEY_PLAN]
|
||||
self.assertEqual(plan.context[PLAN_CONTEXT_PENDING_USER].pk, user.pk)
|
||||
|
||||
def assert_no_flow_configured(self, response):
|
||||
"""Assert that the API reports a missing lockdown flow."""
|
||||
self.assertEqual(response.status_code, 400)
|
||||
body = loads(response.content)
|
||||
self.assertIn("No lockdown flow configured", body["non_field_errors"][0])
|
||||
|
||||
|
||||
@patch_license
|
||||
class TestUsersAccountLockdownAPI(AccountLockdownAPITestCase):
|
||||
"""Test Users Account Lockdown API"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.actor = create_test_user()
|
||||
self.user = self.create_user_with_email()
|
||||
|
||||
def test_account_lockdown_with_change_user_returns_redirect(self):
|
||||
"""Test that account lockdown allows users with change_user permission."""
|
||||
self.actor.assign_perms_to_managed_role("authentik_core.change_user", self.user)
|
||||
self.client.force_login(self.actor)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={"user": self.user.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assert_redirect_targets(response, self.user)
|
||||
|
||||
def test_account_lockdown_no_flow_configured(self):
|
||||
"""Test account lockdown when no flow is configured"""
|
||||
self.brand.flow_lockdown = None
|
||||
self.brand.save()
|
||||
self.actor.assign_perms_to_managed_role("authentik_core.change_user", self.user)
|
||||
self.client.force_login(self.actor)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={"user": self.user.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assert_no_flow_configured(response)
|
||||
|
||||
def test_account_lockdown_unauthenticated(self):
|
||||
"""Test account lockdown requires authentication"""
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={"user": self.user.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_account_lockdown_without_change_user_denied(self):
|
||||
"""Test account lockdown denies users without change_user permission."""
|
||||
self.client.force_login(self.actor)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={"user": self.user.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_account_lockdown_self_returns_redirect(self):
|
||||
"""Test successful self-service account lockdown returns a direct redirect."""
|
||||
self.client.force_login(self.user)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assert_redirect_targets(response, self.user)
|
||||
|
||||
def test_account_lockdown_self_target_without_change_user_returns_redirect(self):
|
||||
"""Test self-service does not require change_user permission."""
|
||||
self.client.force_login(self.user)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-account-lockdown"),
|
||||
data={"user": self.user.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assert_redirect_targets(response, self.user)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Tests for the packaged account-lockdown blueprint."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.blueprints.models import BlueprintInstance
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.blueprints.v1.tasks import blueprints_find, check_blueprint_v1_file
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.flows.models import Flow
|
||||
|
||||
BLUEPRINT_PATH = "example/flow-default-account-lockdown.yaml"
|
||||
|
||||
|
||||
class TestAccountLockdownBlueprint(TransactionTestCase):
|
||||
"""Test the packaged account-lockdown blueprint behavior."""
|
||||
|
||||
def test_blueprint_is_not_auto_instantiated(self):
|
||||
"""Test the packaged blueprint is opt-in and skipped by discovery."""
|
||||
BlueprintInstance.objects.filter(path=BLUEPRINT_PATH).delete()
|
||||
blueprint = next(item for item in blueprints_find() if item.path == BLUEPRINT_PATH)
|
||||
|
||||
check_blueprint_v1_file(blueprint)
|
||||
|
||||
self.assertFalse(BlueprintInstance.objects.filter(path=BLUEPRINT_PATH).exists())
|
||||
|
||||
def test_blueprint_requires_licensed_context(self):
|
||||
"""Test manual import only creates flows when enterprise is licensed."""
|
||||
content = BlueprintInstance(path=BLUEPRINT_PATH).retrieve()
|
||||
license_key = LicenseKey("test", 253402300799, "Test license", 1000, 1000)
|
||||
|
||||
with patch("authentik.enterprise.license.LicenseKey.get_total", return_value=license_key):
|
||||
importer = Importer.from_string(content, {"goauthentik.io/enterprise/licensed": False})
|
||||
valid, logs = importer.validate()
|
||||
self.assertTrue(valid, logs)
|
||||
self.assertTrue(importer.apply())
|
||||
self.assertFalse(Flow.objects.filter(slug="default-account-lockdown").exists())
|
||||
self.assertFalse(Flow.objects.filter(slug="default-account-lockdown-complete").exists())
|
||||
|
||||
importer = Importer.from_string(content, {"goauthentik.io/enterprise/licensed": True})
|
||||
valid, logs = importer.validate()
|
||||
self.assertTrue(valid, logs)
|
||||
self.assertTrue(importer.apply())
|
||||
self.assertTrue(Flow.objects.filter(slug="default-account-lockdown").exists())
|
||||
self.assertTrue(Flow.objects.filter(slug="default-account-lockdown-complete").exists())
|
||||
627
authentik/enterprise/stages/account_lockdown/tests/test_stage.py
Normal file
627
authentik/enterprise/stages/account_lockdown/tests/test_stage.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""Account lockdown stage tests"""
|
||||
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from threading import Event as ThreadEvent
|
||||
from threading import Thread
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.db import connection
|
||||
from django.http import HttpResponse
|
||||
from django.test import TransactionTestCase
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
from dramatiq.results.errors import ResultTimeout
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, Session, Token, TokenIntents
|
||||
from authentik.core.tests.utils import (
|
||||
RequestFactory,
|
||||
create_test_admin_user,
|
||||
create_test_cert,
|
||||
create_test_flow,
|
||||
create_test_user,
|
||||
)
|
||||
from authentik.enterprise.stages.account_lockdown.models import AccountLockdownStage
|
||||
from authentik.enterprise.stages.account_lockdown.stage import (
|
||||
LOCKDOWN_EVENT_ACTION_ID,
|
||||
PLAN_CONTEXT_LOCKDOWN_REASON,
|
||||
AccountLockdownStageView,
|
||||
can_lock_user,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.markers import StageMarker
|
||||
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
DeviceToken,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
RefreshToken,
|
||||
)
|
||||
from authentik.providers.saml.models import SAMLProvider, SAMLSession
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
patch_enterprise_enabled = patch(
|
||||
"authentik.enterprise.apps.AuthentikEnterpriseConfig.check_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
|
||||
class AccountLockdownStageTestMixin:
|
||||
"""Shared setup helpers for account lockdown stage tests."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.patch_enterprise_enabled = patch_enterprise_enabled.start()
|
||||
cls.patch_event_dispatch = patch("authentik.events.tasks.event_trigger_dispatch.send")
|
||||
cls.patch_event_dispatch.start()
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.patch_event_dispatch.stop()
|
||||
patch_enterprise_enabled.stop()
|
||||
super().tearDownClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.target_user = create_test_admin_user()
|
||||
self.flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
|
||||
self.stage = AccountLockdownStage.objects.create(
|
||||
name="lockdown",
|
||||
)
|
||||
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=0)
|
||||
self.request_factory = RequestFactory()
|
||||
|
||||
def make_stage_view(self, plan: FlowPlan):
|
||||
def _stage_ok():
|
||||
return HttpResponse(status=204)
|
||||
|
||||
def _stage_invalid(_error_message=None):
|
||||
return HttpResponse(status=400)
|
||||
|
||||
return AccountLockdownStageView(
|
||||
SimpleNamespace(
|
||||
plan=plan,
|
||||
current_stage=self.stage,
|
||||
current_binding=self.binding,
|
||||
flow=self.flow,
|
||||
stage_ok=_stage_ok,
|
||||
stage_invalid=_stage_invalid,
|
||||
)
|
||||
)
|
||||
|
||||
def make_request(self, *, user=None, query=None):
|
||||
return self.request_factory.post(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
query_params=query or {},
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_lockdown_event(self):
|
||||
"""Return the account-lockdown user-write event."""
|
||||
return Event.objects.filter(
|
||||
action=EventAction.USER_WRITE,
|
||||
context__action_id=LOCKDOWN_EVENT_ACTION_ID,
|
||||
).first()
|
||||
|
||||
|
||||
class TestAccountLockdownStage(AccountLockdownStageTestMixin, FlowTestCase):
|
||||
"""Account lockdown stage tests"""
|
||||
|
||||
def test_lockdown_no_target(self):
|
||||
"""Test lockdown stage with no pending user fails"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
|
||||
response = view.dispatch(self.make_request())
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def test_lockdown_with_pending_user(self):
|
||||
"""Test lockdown stage with a pending target user."""
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_LOCKDOWN_REASON] = "Security incident"
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.user)
|
||||
|
||||
self.assertTrue(can_lock_user(request.user, self.target_user))
|
||||
response = view.dispatch(request)
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
self.assertFalse(self.target_user.has_usable_password())
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
# Check event was created
|
||||
event = self.get_lockdown_event()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.context["action_id"], LOCKDOWN_EVENT_ACTION_ID)
|
||||
self.assertEqual(event.context["reason"], "Security incident")
|
||||
self.assertEqual(event.context["affected_user"], self.target_user.username)
|
||||
|
||||
def test_lockdown_with_pending_user_reason(self):
|
||||
"""Test lockdown stage with a pending target and explicit reason."""
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_LOCKDOWN_REASON] = "Compromised account"
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.user)
|
||||
|
||||
self.assertTrue(can_lock_user(request.user, self.target_user))
|
||||
response = view.dispatch(request)
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_lockdown_reason_from_prompt(self):
|
||||
"""Test lockdown stage reads the reason from prompt data."""
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PROMPT] = {
|
||||
PLAN_CONTEXT_LOCKDOWN_REASON: "User requested lockdown",
|
||||
}
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.user)
|
||||
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
|
||||
|
||||
event = self.get_lockdown_event()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(event.context["reason"], "User requested lockdown")
|
||||
|
||||
def test_lockdown_event_failure_does_not_fail_self_service(self):
|
||||
"""Test lockdown still succeeds when event emission fails."""
|
||||
self.stage.delete_sessions = False
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.target_user)
|
||||
|
||||
original_event_new = Event.new
|
||||
|
||||
def _event_new_side_effect(action, *args, **kwargs):
|
||||
if (
|
||||
action == EventAction.USER_WRITE
|
||||
and kwargs.get("action_id") == LOCKDOWN_EVENT_ACTION_ID
|
||||
):
|
||||
raise RuntimeError("simulated event failure")
|
||||
return original_event_new(action, *args, **kwargs)
|
||||
|
||||
with patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.Event.new",
|
||||
side_effect=_event_new_side_effect,
|
||||
):
|
||||
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
|
||||
def test_dispatch_records_success_when_event_emission_fails(self):
|
||||
"""Test dispatch still completes if event emission fails."""
|
||||
self.stage.delete_sessions = False
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(
|
||||
user=self.target_user,
|
||||
)
|
||||
|
||||
original_event_new = Event.new
|
||||
|
||||
def _event_new_side_effect(action, *args, **kwargs):
|
||||
if (
|
||||
action == EventAction.USER_WRITE
|
||||
and kwargs.get("action_id") == LOCKDOWN_EVENT_ACTION_ID
|
||||
):
|
||||
raise RuntimeError("simulated event failure")
|
||||
return original_event_new(action, *args, **kwargs)
|
||||
|
||||
with patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.Event.new",
|
||||
side_effect=_event_new_side_effect,
|
||||
):
|
||||
response = view.dispatch(request)
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_lockdown_self_service_redirects_to_completion_flow(self):
|
||||
"""Test self-service lockdown redirects to completion flow when sessions are deleted."""
|
||||
completion_flow = create_test_flow(FlowDesignation.STAGE_CONFIGURATION)
|
||||
self.stage.self_service_completion_flow = completion_flow
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.target_user)
|
||||
view._lockdown_user(request, self.stage, self.target_user, view.get_reason())
|
||||
response = view._self_service_completion_response(request)
|
||||
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(
|
||||
response.url,
|
||||
reverse("authentik_core:if-flow", kwargs={"flow_slug": completion_flow.slug}),
|
||||
)
|
||||
|
||||
def test_lockdown_self_service_requires_completion_flow(self):
|
||||
"""Test self-service lockdown fails before deleting sessions without a completion flow."""
|
||||
self.stage.self_service_completion_flow = None
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=self.target_user)
|
||||
|
||||
response = view.dispatch(request)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertTrue(self.target_user.is_active)
|
||||
|
||||
def test_lockdown_denies_other_user_without_permission(self):
|
||||
"""Test lockdown stage rejects non-self requests without change_user permission."""
|
||||
actor = create_test_user()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.target_user
|
||||
view = self.make_stage_view(plan)
|
||||
request = self.make_request(user=actor)
|
||||
|
||||
self.assertFalse(can_lock_user(request.user, self.target_user))
|
||||
response = view.dispatch(request)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def test_lockdown_revokes_tokens(self):
|
||||
"""Test lockdown stage revokes tokens"""
|
||||
Token.objects.create(
|
||||
user=self.target_user,
|
||||
identifier="test-token",
|
||||
intent=TokenIntents.INTENT_API,
|
||||
key=generate_id(),
|
||||
expiring=False,
|
||||
)
|
||||
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 1)
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 0)
|
||||
|
||||
def test_lockdown_revokes_provider_tokens(self):
|
||||
"""Test lockdown stage revokes provider tokens and sessions."""
|
||||
oauth_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[
|
||||
RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver/callback")
|
||||
],
|
||||
signing_key=create_test_cert(),
|
||||
)
|
||||
saml_provider = SAMLProvider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="https://sp.example.com/acs",
|
||||
issuer_override="https://idp.example.com",
|
||||
)
|
||||
session = Session.objects.create(
|
||||
session_key=generate_id(),
|
||||
expires=timezone.now() + timezone.timedelta(hours=1),
|
||||
last_ip="127.0.0.1",
|
||||
)
|
||||
auth_session = AuthenticatedSession.objects.create(
|
||||
session=session,
|
||||
user=self.target_user,
|
||||
)
|
||||
grant_kwargs = {
|
||||
"provider": oauth_provider,
|
||||
"user": self.target_user,
|
||||
"auth_time": timezone.now(),
|
||||
"_scope": "openid profile",
|
||||
"expiring": False,
|
||||
}
|
||||
token_kwargs = grant_kwargs | {"_id_token": json.dumps(asdict(IDToken("foo", "bar")))}
|
||||
AuthorizationCode.objects.create(
|
||||
code=generate_id(),
|
||||
session=auth_session,
|
||||
**grant_kwargs,
|
||||
)
|
||||
AccessToken.objects.create(
|
||||
token=generate_id(),
|
||||
session=auth_session,
|
||||
**token_kwargs,
|
||||
)
|
||||
RefreshToken.objects.create(
|
||||
token=generate_id(),
|
||||
session=auth_session,
|
||||
**token_kwargs,
|
||||
)
|
||||
DeviceToken.objects.create(
|
||||
provider=oauth_provider,
|
||||
user=self.target_user,
|
||||
session=auth_session,
|
||||
_scope="openid profile",
|
||||
expiring=False,
|
||||
)
|
||||
SAMLSession.objects.create(
|
||||
provider=saml_provider,
|
||||
user=self.target_user,
|
||||
session=auth_session,
|
||||
session_index=generate_id(),
|
||||
name_id=self.target_user.email,
|
||||
expires=timezone.now() + timezone.timedelta(hours=1),
|
||||
expiring=True,
|
||||
)
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
self.assertEqual(AuthorizationCode.objects.filter(user=self.target_user).count(), 0)
|
||||
self.assertEqual(AccessToken.objects.filter(user=self.target_user).count(), 0)
|
||||
self.assertEqual(RefreshToken.objects.filter(user=self.target_user).count(), 0)
|
||||
self.assertEqual(DeviceToken.objects.filter(user=self.target_user).count(), 0)
|
||||
self.assertEqual(SAMLSession.objects.filter(user=self.target_user).count(), 0)
|
||||
|
||||
def test_lockdown_selective_actions(self):
|
||||
"""Test lockdown stage with selective actions"""
|
||||
self.stage.deactivate_user = True
|
||||
self.stage.set_unusable_password = False
|
||||
self.stage.delete_sessions = False
|
||||
self.stage.revoke_tokens = False
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.set_password("testpassword")
|
||||
self.target_user.save()
|
||||
|
||||
Token.objects.create(
|
||||
user=self.target_user,
|
||||
identifier="test-token",
|
||||
intent=TokenIntents.INTENT_API,
|
||||
key=generate_id(),
|
||||
expiring=False,
|
||||
)
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
# User should be deactivated
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
# Password should still be usable
|
||||
self.assertTrue(self.target_user.has_usable_password())
|
||||
# Token should still exist
|
||||
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 1)
|
||||
|
||||
def test_lockdown_no_actions(self):
|
||||
"""Test lockdown stage with all actions disabled"""
|
||||
self.stage.deactivate_user = False
|
||||
self.stage.set_unusable_password = False
|
||||
self.stage.delete_sessions = False
|
||||
self.stage.revoke_tokens = False
|
||||
self.stage.save()
|
||||
|
||||
self.target_user.is_active = True
|
||||
self.target_user.set_password("testpassword")
|
||||
self.target_user.save()
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
# User should still be active
|
||||
self.assertTrue(self.target_user.is_active)
|
||||
# Password should still be usable
|
||||
self.assertTrue(self.target_user.has_usable_password())
|
||||
# Event should still be created
|
||||
event = self.get_lockdown_event()
|
||||
self.assertIsNotNone(event)
|
||||
|
||||
def test_lockdown_deactivation_inhibits_signal_dispatch_until_after_commit(self):
|
||||
"""Test lockdown queues explicit outgoing syncs after the deactivation transaction."""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.sync_outgoing_inhibit_dispatch"
|
||||
) as inhibit,
|
||||
patch.object(view, "_sync_deactivated_user_to_outgoing_providers") as sync_outgoing,
|
||||
):
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
inhibit.assert_called_once()
|
||||
sync_outgoing.assert_called_once()
|
||||
synced_user = sync_outgoing.call_args.args[0]
|
||||
self.assertEqual(synced_user.pk, self.target_user.pk)
|
||||
self.assertFalse(synced_user.is_active)
|
||||
|
||||
def test_lockdown_waits_for_direct_outgoing_provider_syncs(self):
|
||||
"""Test direct outgoing sync tasks are enqueued and waited on."""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
provider = SimpleNamespace(name="outgoing", pk=1, sync_page_timeout="seconds=5")
|
||||
task_sync_direct = MagicMock()
|
||||
task_sync_direct.message_with_options.return_value = "direct-message"
|
||||
provider_model = SimpleNamespace(
|
||||
objects=SimpleNamespace(filter=MagicMock(return_value=[provider]))
|
||||
)
|
||||
task_group = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.get_outgoing_sync_tasks",
|
||||
return_value=((provider_model, task_sync_direct),),
|
||||
),
|
||||
patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.group",
|
||||
return_value=task_group,
|
||||
) as task_group_cls,
|
||||
):
|
||||
view._sync_deactivated_user_to_outgoing_providers(self.target_user)
|
||||
|
||||
task_sync_direct.message_with_options.assert_called_once_with(
|
||||
args=(class_to_path(type(self.target_user)), self.target_user.pk, provider.pk),
|
||||
rel_obj=provider,
|
||||
time_limit=5000,
|
||||
uid=f"{provider.name}:user:{self.target_user.pk}:direct",
|
||||
)
|
||||
task_group_cls.assert_called_once_with(["direct-message"])
|
||||
task_group.run.return_value.wait.assert_called_once_with(timeout=5000)
|
||||
|
||||
def test_lockdown_outgoing_provider_sync_timeout_leaves_tasks_running(self):
|
||||
"""Test timeout while waiting for direct outgoing syncs does not fail lockdown."""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
provider = SimpleNamespace(name="outgoing", pk=1, sync_page_timeout="seconds=5")
|
||||
task_sync_direct = MagicMock()
|
||||
task_sync_direct.message_with_options.return_value = "direct-message"
|
||||
provider_model = SimpleNamespace(
|
||||
objects=SimpleNamespace(filter=MagicMock(return_value=[provider]))
|
||||
)
|
||||
task_group = MagicMock()
|
||||
task_group.run.return_value.wait.side_effect = ResultTimeout("timed out")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.get_outgoing_sync_tasks",
|
||||
return_value=((provider_model, task_sync_direct),),
|
||||
),
|
||||
patch(
|
||||
"authentik.enterprise.stages.account_lockdown.stage.group",
|
||||
return_value=task_group,
|
||||
),
|
||||
):
|
||||
view._sync_deactivated_user_to_outgoing_providers(self.target_user)
|
||||
|
||||
task_group.run.assert_called_once_with()
|
||||
task_group.run.return_value.wait.assert_called_once_with(timeout=5000)
|
||||
|
||||
def test_lockdown_outgoing_provider_sync_failure_does_not_fail_lockdown(self):
|
||||
"""Test completed local lockdown still emits an event if outgoing sync fails."""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
|
||||
with patch.object(
|
||||
view,
|
||||
"_sync_deactivated_user_to_outgoing_providers",
|
||||
side_effect=ValueError("sync failed"),
|
||||
):
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
|
||||
self.target_user.refresh_from_db()
|
||||
self.assertFalse(self.target_user.is_active)
|
||||
event = self.get_lockdown_event()
|
||||
self.assertIsNotNone(event)
|
||||
|
||||
|
||||
class TestAccountLockdownStageConcurrency(AccountLockdownStageTestMixin, TransactionTestCase):
|
||||
"""Account lockdown concurrency tests."""
|
||||
|
||||
def test_lockdown_retries_when_another_transaction_recreates_a_token(self):
|
||||
"""Lockdown should remove a token recreated before the retry check runs."""
|
||||
Token.objects.create(
|
||||
user=self.target_user,
|
||||
identifier=f"initial-token-{generate_id()}",
|
||||
intent=TokenIntents.INTENT_API,
|
||||
key=generate_id(),
|
||||
expiring=False,
|
||||
)
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
view = self.make_stage_view(plan)
|
||||
original_has_artifacts = view._has_lockdown_artifacts
|
||||
target_user = self.target_user
|
||||
thread_ready = ThreadEvent()
|
||||
start_create = ThreadEvent()
|
||||
thread_done = ThreadEvent()
|
||||
thread_errors = []
|
||||
|
||||
class TokenCreatorThread(Thread):
|
||||
__test__ = False
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
thread_ready.set()
|
||||
if not start_create.wait(timeout=5):
|
||||
thread_errors.append("timed out waiting to recreate token")
|
||||
return
|
||||
Token.objects.create(
|
||||
user=target_user,
|
||||
identifier=f"concurrent-token-{generate_id()}",
|
||||
intent=TokenIntents.INTENT_API,
|
||||
key=generate_id(),
|
||||
expiring=False,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
thread_errors.append(exc)
|
||||
finally:
|
||||
thread_done.set()
|
||||
connection.close()
|
||||
|
||||
def has_artifacts_after_concurrent_create(stage, user):
|
||||
if not start_create.is_set():
|
||||
start_create.set()
|
||||
self.assertTrue(
|
||||
thread_done.wait(timeout=30),
|
||||
(
|
||||
"Concurrent token creation did not complete "
|
||||
f"before retry check: {thread_errors}"
|
||||
),
|
||||
)
|
||||
return original_has_artifacts(stage, user)
|
||||
|
||||
creator = TokenCreatorThread()
|
||||
with patch.object(
|
||||
view, "_has_lockdown_artifacts", side_effect=has_artifacts_after_concurrent_create
|
||||
):
|
||||
creator.start()
|
||||
self.assertTrue(
|
||||
thread_ready.wait(timeout=5),
|
||||
"Concurrent token creation thread did not start",
|
||||
)
|
||||
view._lockdown_user(self.make_request(user=self.user), self.stage, self.target_user, "")
|
||||
creator.join()
|
||||
|
||||
self.assertEqual(thread_errors, [])
|
||||
self.assertEqual(Token.objects.filter(user=self.target_user).count(), 0)
|
||||
5
authentik/enterprise/stages/account_lockdown/urls.py
Normal file
5
authentik/enterprise/stages/account_lockdown/urls.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""API URLs"""
|
||||
|
||||
from authentik.enterprise.stages.account_lockdown.api import AccountLockdownStageViewSet
|
||||
|
||||
api_urlpatterns = [("stages/account_lockdown", AccountLockdownStageViewSet)]
|
||||
@@ -8,7 +8,6 @@ from inspect import currentframe
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.apps import apps
|
||||
from django.db import models
|
||||
@@ -410,7 +409,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
)
|
||||
notification.save()
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
build_user_group(notification.user),
|
||||
{
|
||||
"type": "event.notification",
|
||||
|
||||
@@ -11,7 +11,7 @@ from django.http import HttpRequest
|
||||
from rest_framework.request import Request
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.core.signals import login_failed, password_changed
|
||||
from authentik.core.signals import login_failed, password_changed, password_hash_changed
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.flows.planner import (
|
||||
@@ -112,8 +112,15 @@ def on_invitation_used(sender, request: HttpRequest, invitation: Invitation, **_
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_hash_changed)
|
||||
@receiver(password_changed)
|
||||
def on_password_changed(sender, user: User, password: str, request: HttpRequest | None, **_):
|
||||
def on_password_changed(
|
||||
sender,
|
||||
user: User,
|
||||
password: str | None = None,
|
||||
request: HttpRequest | None = None,
|
||||
**_,
|
||||
):
|
||||
"""Log password change"""
|
||||
Event.new(EventAction.PASSWORD_SET).from_http(request, user=user)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.views.debug import SafeExceptionReporterFilter
|
||||
@@ -10,7 +11,7 @@ from guardian.shortcuts import get_anonymous_user
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Event
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_id
|
||||
@@ -213,3 +214,14 @@ class TestEvents(TestCase):
|
||||
event = Event.new("unittest", foo="foo bar \u0000 baz")
|
||||
event.save()
|
||||
self.assertEqual(event.context["foo"], "foo bar baz")
|
||||
|
||||
def test_password_set_signal_on_set_password_from_hash(self):
|
||||
"""Changing password from hash should still emit an audit event."""
|
||||
user = create_test_user()
|
||||
old_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
|
||||
|
||||
user.set_password_from_hash(make_password(generate_id()))
|
||||
user.save()
|
||||
|
||||
new_count = Event.objects.filter(action=EventAction.PASSWORD_SET, user__pk=user.pk).count()
|
||||
self.assertEqual(new_count, old_count + 1)
|
||||
|
||||
@@ -29,6 +29,7 @@ class RefreshOtherFlowsAfterAuthentication(Flag[bool], key="flows_refresh_others
|
||||
default = False
|
||||
visibility = "public"
|
||||
description = _("Refresh other tabs after successful authentication.")
|
||||
deprecated = True
|
||||
|
||||
|
||||
class ContinuousLogin(Flag[bool], key="flows_continuous_login"):
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
height: 100%;
|
||||
}
|
||||
body {
|
||||
background-image: url("{{ flow_background_url }}");
|
||||
background-image: url("{{ flow_background_url|iriencode|safe }}");
|
||||
background-repeat: no-repeat;
|
||||
background-size: cover;
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
<script src="{% versioned_script 'dist/flow/FlowInterface-%v.js' %}" type="module"></script>
|
||||
<style data-id="flow-css">
|
||||
:root {
|
||||
--ak-global--background-image: url("{{ flow_background_url }}");
|
||||
--ak-global--background-image: url("{{ flow_background_url|iriencode|safe }}");
|
||||
}
|
||||
</style>
|
||||
{% endblock %}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""stage view tests"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.core.tests.utils import RequestFactory as AuthentikRequestFactory
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.flows.models import Flow, FlowStageBinding
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.flows.views.executor import FlowExecutorView
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
@@ -42,6 +44,46 @@ class TestViews(TestCase):
|
||||
"/static/dist/assets/images/flow_background.jpg",
|
||||
)
|
||||
|
||||
def test_flow_interface_css_background_preserves_presigned_url_query(self):
|
||||
"""Test flow CSS keeps signed URL query separators intact."""
|
||||
flow = create_test_flow()
|
||||
background_url = (
|
||||
"https://s3.ca-central-1.amazonaws.com/example/media/public/background.png"
|
||||
"?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=credential"
|
||||
"&X-Amz-Signature=signature"
|
||||
)
|
||||
|
||||
with patch.object(Flow, "background_url", return_value=background_url):
|
||||
response = self.client.get(
|
||||
reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
|
||||
)
|
||||
|
||||
self.assertContains(
|
||||
response,
|
||||
f'--ak-global--background-image: url("{background_url}");',
|
||||
html=False,
|
||||
)
|
||||
|
||||
def test_flow_sfe_css_background_preserves_presigned_url_query(self):
|
||||
"""Test SFE flow CSS keeps signed URL query separators intact."""
|
||||
flow = create_test_flow()
|
||||
background_url = (
|
||||
"https://s3.ca-central-1.amazonaws.com/example/media/public/background.png"
|
||||
"?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=credential"
|
||||
"&X-Amz-Signature=signature"
|
||||
)
|
||||
|
||||
with patch.object(Flow, "background_url", return_value=background_url):
|
||||
response = self.client.get(
|
||||
reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) + "?sfe"
|
||||
)
|
||||
|
||||
self.assertContains(
|
||||
response,
|
||||
f'background-image: url("{background_url}");',
|
||||
html=False,
|
||||
)
|
||||
|
||||
|
||||
def view_tester_factory(view_class: type[StageView]) -> Callable:
|
||||
"""Test a form"""
|
||||
|
||||
@@ -53,6 +53,16 @@ class TestEndSessionView(OAuthTestCase):
|
||||
self.brand.flow_invalidation = self.invalidation_flow
|
||||
self.brand.save()
|
||||
|
||||
def _id_token_hint(self, host: str) -> str:
|
||||
"""Issue a valid id_token_hint for the test provider under the given host."""
|
||||
return self.provider.encode(
|
||||
{
|
||||
"iss": f"http://{host}/application/o/{self.app.slug}/",
|
||||
"aud": self.provider.client_id,
|
||||
"sub": str(self.user.pk),
|
||||
}
|
||||
)
|
||||
|
||||
def test_post_logout_redirect_uri_strict_match(self):
|
||||
"""Test strict URI matching redirects to flow"""
|
||||
self.client.force_login(self.user)
|
||||
@@ -61,7 +71,10 @@ class TestEndSessionView(OAuthTestCase):
|
||||
"authentik_providers_oauth2:end-session",
|
||||
kwargs={"application_slug": self.app.slug},
|
||||
),
|
||||
{"post_logout_redirect_uri": "http://testserver/logout"},
|
||||
{
|
||||
"post_logout_redirect_uri": "http://testserver/logout",
|
||||
"id_token_hint": self._id_token_hint(self.brand.domain),
|
||||
},
|
||||
HTTP_HOST=self.brand.domain,
|
||||
)
|
||||
# Should redirect to the invalidation flow
|
||||
@@ -69,7 +82,12 @@ class TestEndSessionView(OAuthTestCase):
|
||||
self.assertIn(self.invalidation_flow.slug, response.url)
|
||||
|
||||
def test_post_logout_redirect_uri_strict_no_match(self):
|
||||
"""Test strict URI not matching still proceeds with flow (no redirect URI in context)"""
|
||||
"""Test strict URI not matching returns an error and does not start logout flow.
|
||||
|
||||
Required by OIDC RP-Initiated Logout 1.0: on an unregistered
|
||||
post_logout_redirect_uri, the OP MUST NOT redirect and MUST NOT proceed with
|
||||
logout that targets the RP.
|
||||
"""
|
||||
self.client.force_login(self.user)
|
||||
invalid_uri = "http://testserver/other"
|
||||
response = self.client.get(
|
||||
@@ -77,12 +95,14 @@ class TestEndSessionView(OAuthTestCase):
|
||||
"authentik_providers_oauth2:end-session",
|
||||
kwargs={"application_slug": self.app.slug},
|
||||
),
|
||||
{"post_logout_redirect_uri": invalid_uri},
|
||||
{
|
||||
"post_logout_redirect_uri": invalid_uri,
|
||||
"id_token_hint": self._id_token_hint(self.brand.domain),
|
||||
},
|
||||
HTTP_HOST=self.brand.domain,
|
||||
)
|
||||
# Should still redirect to flow, but invalid URI should not be in response
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertNotIn(invalid_uri, response.url)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertNotIn(invalid_uri, response.content.decode())
|
||||
|
||||
def test_post_logout_redirect_uri_regex_match(self):
|
||||
"""Test regex URI matching redirects to flow"""
|
||||
@@ -92,7 +112,10 @@ class TestEndSessionView(OAuthTestCase):
|
||||
"authentik_providers_oauth2:end-session",
|
||||
kwargs={"application_slug": self.app.slug},
|
||||
),
|
||||
{"post_logout_redirect_uri": "https://app.example.com/logout"},
|
||||
{
|
||||
"post_logout_redirect_uri": "https://app.example.com/logout",
|
||||
"id_token_hint": self._id_token_hint(self.brand.domain),
|
||||
},
|
||||
HTTP_HOST=self.brand.domain,
|
||||
)
|
||||
# Should redirect to the invalidation flow
|
||||
@@ -100,7 +123,7 @@ class TestEndSessionView(OAuthTestCase):
|
||||
self.assertIn(self.invalidation_flow.slug, response.url)
|
||||
|
||||
def test_post_logout_redirect_uri_regex_no_match(self):
|
||||
"""Test regex URI not matching"""
|
||||
"""Test regex URI not matching returns an error and does not start logout flow."""
|
||||
self.client.force_login(self.user)
|
||||
invalid_uri = "https://malicious.com/logout"
|
||||
response = self.client.get(
|
||||
@@ -108,12 +131,14 @@ class TestEndSessionView(OAuthTestCase):
|
||||
"authentik_providers_oauth2:end-session",
|
||||
kwargs={"application_slug": self.app.slug},
|
||||
),
|
||||
{"post_logout_redirect_uri": invalid_uri},
|
||||
{
|
||||
"post_logout_redirect_uri": invalid_uri,
|
||||
"id_token_hint": self._id_token_hint(self.brand.domain),
|
||||
},
|
||||
HTTP_HOST=self.brand.domain,
|
||||
)
|
||||
# Should still proceed to flow, but invalid URI should not be in response
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertNotIn(invalid_uri, response.url)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertNotIn(invalid_uri, response.content.decode())
|
||||
|
||||
def test_state_parameter_appended_to_uri(self):
|
||||
"""Test state parameter is appended to validated redirect URI"""
|
||||
@@ -123,6 +148,7 @@ class TestEndSessionView(OAuthTestCase):
|
||||
{
|
||||
"post_logout_redirect_uri": "http://testserver/logout",
|
||||
"state": "test-state-123",
|
||||
"id_token_hint": self._id_token_hint("testserver"),
|
||||
},
|
||||
)
|
||||
request.user = self.user
|
||||
@@ -132,6 +158,7 @@ class TestEndSessionView(OAuthTestCase):
|
||||
view.request = request
|
||||
view.kwargs = {"application_slug": self.app.slug}
|
||||
view.resolve_provider_application()
|
||||
view.validate()
|
||||
|
||||
self.assertIn("state=test-state-123", view.post_logout_redirect_uri)
|
||||
|
||||
@@ -146,6 +173,7 @@ class TestEndSessionView(OAuthTestCase):
|
||||
{
|
||||
"post_logout_redirect_uri": "http://testserver/logout",
|
||||
"state": "xyz789",
|
||||
"id_token_hint": self._id_token_hint(self.brand.domain),
|
||||
},
|
||||
HTTP_HOST=self.brand.domain,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,8 @@ from urllib.parse import quote, urlparse
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from jwt import PyJWTError
|
||||
from jwt import decode as jwt_decode
|
||||
|
||||
from authentik.common.oauth.constants import (
|
||||
FORBIDDEN_URI_SCHEMES,
|
||||
@@ -21,11 +23,14 @@ from authentik.flows.planner import (
|
||||
from authentik.flows.stage import SessionEndStage
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.views import bad_request_message
|
||||
from authentik.policies.views import PolicyAccessView, RequestValidationError
|
||||
from authentik.policies.views import PolicyAccessView
|
||||
from authentik.providers.iframe_logout import IframeLogoutStageView
|
||||
from authentik.providers.oauth2.errors import TokenError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
JWTAlgorithms,
|
||||
OAuth2LogoutMethod,
|
||||
OAuth2Provider,
|
||||
RedirectURIMatchingMode,
|
||||
)
|
||||
from authentik.providers.oauth2.tasks import send_backchannel_logout_request
|
||||
@@ -47,21 +52,45 @@ class EndSessionView(PolicyAccessView):
|
||||
if not self.flow:
|
||||
raise Http404
|
||||
|
||||
def validate(self):
|
||||
# Parse end session parameters
|
||||
query_dict = self.request.POST if self.request.method == "POST" else self.request.GET
|
||||
state = query_dict.get("state")
|
||||
request_redirect_uri = query_dict.get("post_logout_redirect_uri")
|
||||
id_token_hint = query_dict.get("id_token_hint")
|
||||
self.post_logout_redirect_uri = None
|
||||
|
||||
# OIDC Certification: Verify id_token_hint. If invalid or missing, throw an error
|
||||
if id_token_hint:
|
||||
# Load a fresh provider instance that's not part of the flow
|
||||
# since it'll have the cryptography Certificate that can't be pickled
|
||||
provider = OAuth2Provider.objects.get(pk=self.provider.pk)
|
||||
key, alg = provider.jwt_key
|
||||
if alg != JWTAlgorithms.HS256:
|
||||
key = provider.signing_key.public_key
|
||||
try:
|
||||
jwt_decode(
|
||||
id_token_hint,
|
||||
key,
|
||||
algorithms=[alg],
|
||||
audience=provider.client_id,
|
||||
issuer=provider.get_issuer(self.request),
|
||||
# ID Tokens are short-lived; a logout request arriving
|
||||
# after expiry is still legitimate and must succeed.
|
||||
options={"verify_exp": False},
|
||||
)
|
||||
except PyJWTError:
|
||||
raise TokenError("invalid_request").with_cause(
|
||||
"id_token_hint_decode_failed"
|
||||
) from None
|
||||
|
||||
# Validate post_logout_redirect_uri against registered URIs
|
||||
if request_redirect_uri:
|
||||
# OIDC Certification: id_token_hint required with post_logout_redirect_uri
|
||||
if not id_token_hint:
|
||||
raise TokenError("invalid_request").with_cause("id_token_hint_missing")
|
||||
if urlparse(request_redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||
raise RequestValidationError(
|
||||
bad_request_message(
|
||||
self.request,
|
||||
"Forbidden URI scheme in post_logout_redirect_uri",
|
||||
)
|
||||
)
|
||||
raise TokenError("invalid_request").with_cause("post_logout_redirect_uri")
|
||||
for allowed in self.provider.post_logout_redirect_uris:
|
||||
if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
|
||||
if request_redirect_uri == allowed.url:
|
||||
@@ -71,6 +100,10 @@ class EndSessionView(PolicyAccessView):
|
||||
if fullmatch(allowed.url, request_redirect_uri):
|
||||
self.post_logout_redirect_uri = request_redirect_uri
|
||||
break
|
||||
# OIDC Certification: OP MUST NOT perform post-logout redirection
|
||||
# if the supplied URI does not exactly match a registered one
|
||||
if self.post_logout_redirect_uri is None:
|
||||
raise TokenError("invalid_request").with_cause("invalid_post_logout_redirect_uri")
|
||||
|
||||
# Append state to the redirect URI if both are present
|
||||
if self.post_logout_redirect_uri and state:
|
||||
@@ -91,50 +124,43 @@ class EndSessionView(PolicyAccessView):
|
||||
"<html><body>Logout successful</body></html>", content_type="text/html", status=200
|
||||
)
|
||||
|
||||
# Otherwise, continue with normal policy checks
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
"""Dispatch the flow planner for the invalidation flow"""
|
||||
try:
|
||||
self.validate()
|
||||
except TokenError as exc:
|
||||
return bad_request_message(
|
||||
self.request,
|
||||
exc.description,
|
||||
)
|
||||
planner = FlowPlanner(self.flow)
|
||||
planner.allow_empty_flows = True
|
||||
|
||||
# Build flow context with logout parameters
|
||||
context = {
|
||||
PLAN_CONTEXT_APPLICATION: self.application,
|
||||
}
|
||||
|
||||
# Get session info for logout notifications and token invalidation
|
||||
auth_session = AuthenticatedSession.from_request(request, request.user)
|
||||
|
||||
# Add validated redirect URI (with state appended) to context if available
|
||||
if self.post_logout_redirect_uri:
|
||||
context[PLAN_CONTEXT_POST_LOGOUT_REDIRECT_URI] = self.post_logout_redirect_uri
|
||||
# Invalidate tokens for this provider/session (RP-initiated logout:
|
||||
# user stays logged into authentik, only this provider's tokens are revoked)
|
||||
if request.user.is_authenticated and auth_session:
|
||||
AccessToken.objects.filter(
|
||||
user=request.user,
|
||||
provider=self.provider,
|
||||
session=auth_session,
|
||||
).delete()
|
||||
|
||||
session_key = (
|
||||
auth_session.session.session_key if auth_session and auth_session.session else None
|
||||
)
|
||||
|
||||
# Handle frontchannel logout
|
||||
frontchannel_logout_url = None
|
||||
if self.provider.logout_method == OAuth2LogoutMethod.FRONTCHANNEL:
|
||||
frontchannel_logout_url = build_frontchannel_logout_url(
|
||||
self.provider, request, session_key
|
||||
)
|
||||
|
||||
# Handle backchannel logout
|
||||
if (
|
||||
self.provider.logout_method == OAuth2LogoutMethod.BACKCHANNEL
|
||||
and self.provider.logout_uri
|
||||
):
|
||||
# Find access token to get iss and sub for the logout token
|
||||
access_token = AccessToken.objects.filter(
|
||||
user=request.user,
|
||||
provider=self.provider,
|
||||
@@ -163,9 +189,16 @@ class EndSessionView(PolicyAccessView):
|
||||
}
|
||||
]
|
||||
|
||||
access_tokens = AccessToken.objects.filter(
|
||||
user=request.user,
|
||||
provider=self.provider,
|
||||
)
|
||||
if auth_session:
|
||||
access_tokens = access_tokens.filter(session=auth_session)
|
||||
access_tokens.delete()
|
||||
|
||||
plan = planner.plan(request, context)
|
||||
|
||||
# Inject iframe logout stage if frontchannel logout is configured
|
||||
if frontchannel_logout_url:
|
||||
plan.insert_stage(in_memory_stage(IframeLogoutStageView))
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""RAC Signals"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.db.models.signals import post_delete, post_save, pre_delete
|
||||
@@ -18,7 +17,7 @@ from authentik.providers.rac.models import ConnectionToken, Endpoint
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
build_rac_client_group_session(instance.session.session_key),
|
||||
{"type": "event.disconnect", "reason": "session_logout"},
|
||||
)
|
||||
@@ -28,7 +27,7 @@ def user_session_deleted(sender, instance: AuthenticatedSession, **_):
|
||||
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
|
||||
"""Disconnect session when connection token is deleted"""
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
build_rac_client_group_token(instance.token),
|
||||
{"type": "event.disconnect", "reason": "token_delete"},
|
||||
)
|
||||
|
||||
@@ -24,7 +24,11 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.validation import validate
|
||||
from authentik.common.saml.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
SAML_BINDING_POST,
|
||||
SAML_BINDING_REDIRECT,
|
||||
)
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
|
||||
@@ -55,6 +59,7 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"""SAMLProvider Serializer"""
|
||||
|
||||
url_download_metadata = SerializerMethodField()
|
||||
url_issuer = SerializerMethodField()
|
||||
|
||||
url_sso_post = SerializerMethodField()
|
||||
url_sso_redirect = SerializerMethodField()
|
||||
@@ -85,6 +90,23 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
def get_url_issuer(self, instance: SAMLProvider) -> str:
|
||||
"""Get Issuer/EntityID URL"""
|
||||
if instance.issuer_override:
|
||||
return instance.issuer_override
|
||||
if "request" not in self._context:
|
||||
return DEFAULT_ISSUER
|
||||
request: HttpRequest = self._context["request"]._request
|
||||
try:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": instance.application.slug},
|
||||
)
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_url_sso_post(self, instance: SAMLProvider) -> str:
|
||||
"""Get SSO Post URL"""
|
||||
if "request" not in self._context:
|
||||
@@ -198,7 +220,7 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"acs_url",
|
||||
"sls_url",
|
||||
"audience",
|
||||
"issuer",
|
||||
"issuer_override",
|
||||
"assertion_valid_not_before",
|
||||
"assertion_valid_not_on_or_after",
|
||||
"session_valid_not_on_or_after",
|
||||
@@ -220,6 +242,7 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"default_relay_state",
|
||||
"default_name_id_policy",
|
||||
"url_download_metadata",
|
||||
"url_issuer",
|
||||
"url_sso_post",
|
||||
"url_sso_redirect",
|
||||
"url_sso_init",
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-24 06:03
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_saml", "0021_samlprovider_sign_logout_response"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameField(
|
||||
model_name="samlprovider",
|
||||
old_name="issuer",
|
||||
new_name="issuer_override",
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="samlprovider",
|
||||
name="issuer_override",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
help_text="Also known as EntityID. Providing a value overrides the default issuer generated by authentik.",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlsession",
|
||||
name="issuer",
|
||||
field=models.TextField(
|
||||
default=None, help_text="SAML Issuer used for this session", null=True
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -77,7 +77,14 @@ class SAMLProvider(Provider):
|
||||
"no audience restriction will be added."
|
||||
),
|
||||
)
|
||||
issuer = models.TextField(help_text=_("Also known as EntityID"), default="authentik")
|
||||
issuer_override = models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
help_text=_(
|
||||
"Also known as EntityID. Providing a value overrides the default issuer "
|
||||
"generated by authentik."
|
||||
),
|
||||
)
|
||||
sls_url = models.TextField(
|
||||
blank=True,
|
||||
validators=[DomainlessURLValidator(schemes=("http", "https"))],
|
||||
@@ -318,6 +325,9 @@ class SAMLSession(InternallyManagedMixin, SerializerModel, ExpiringModel):
|
||||
session_index = models.TextField(help_text=_("SAML SessionIndex for this session"))
|
||||
name_id = models.TextField(help_text=_("SAML NameID value for this session"))
|
||||
name_id_format = models.TextField(default="", blank=True, help_text=_("SAML NameID format"))
|
||||
issuer = models.TextField(
|
||||
default=None, null=True, help_text=_("SAML Issuer used for this session")
|
||||
)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,6 +6,7 @@ from types import GeneratorType
|
||||
|
||||
import xmlsec
|
||||
from django.http import HttpRequest
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from lxml import etree # nosec
|
||||
from lxml.etree import Element, SubElement, _Element # nosec
|
||||
@@ -63,6 +64,7 @@ class AssertionProcessor:
|
||||
session_index: str
|
||||
name_id: str
|
||||
name_id_format: str
|
||||
issuer: str
|
||||
session_not_on_or_after_datetime: datetime
|
||||
|
||||
def __init__(self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest):
|
||||
@@ -137,10 +139,24 @@ class AssertionProcessor:
|
||||
continue
|
||||
return attribute_statement
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value, with fallback to generated URL if empty"""
|
||||
# If user has set an override issuer, use it
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
|
||||
return self.http_request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": self.provider.application.slug},
|
||||
)
|
||||
)
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer Element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer", nsmap=NS_MAP)
|
||||
issuer.text = self.provider.issuer
|
||||
self.issuer = self._get_issuer_value()
|
||||
issuer.text = self.issuer
|
||||
return issuer
|
||||
|
||||
def get_assertion_auth_n_statement(self) -> Element:
|
||||
|
||||
@@ -8,6 +8,7 @@ from lxml import etree # nosec
|
||||
from lxml.etree import Element, _Element
|
||||
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
NS_SAML_ASSERTION,
|
||||
@@ -33,11 +34,12 @@ class LogoutRequestProcessor:
|
||||
name_id_format: str
|
||||
session_index: str | None
|
||||
relay_state: str | None
|
||||
issuer: str | None
|
||||
|
||||
_issue_instant: str
|
||||
_request_id: str
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
provider: SAMLProvider,
|
||||
user: User | None,
|
||||
@@ -46,6 +48,7 @@ class LogoutRequestProcessor:
|
||||
name_id_format: str = SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.user = user
|
||||
@@ -54,14 +57,23 @@ class LogoutRequestProcessor:
|
||||
self.name_id_format = name_id_format
|
||||
self.session_index = session_index
|
||||
self.relay_state = relay_state
|
||||
self.issuer = issuer
|
||||
|
||||
self._issue_instant = get_time_string()
|
||||
self._request_id = get_random_id()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value from session, with fallback to provider"""
|
||||
if self.issuer:
|
||||
return self.issuer
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
issuer.text = self.provider.issuer
|
||||
issuer.text = self._get_issuer_value()
|
||||
return issuer
|
||||
|
||||
def get_name_id(self) -> Element:
|
||||
|
||||
@@ -8,6 +8,7 @@ from lxml import etree
|
||||
from lxml.etree import Element, SubElement
|
||||
|
||||
from authentik.common.saml.constants import (
|
||||
DEFAULT_ISSUER,
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
NS_SAML_ASSERTION,
|
||||
@@ -28,27 +29,38 @@ class LogoutResponseProcessor:
|
||||
logout_request: LogoutRequest
|
||||
destination: str | None
|
||||
relay_state: str | None
|
||||
issuer: str | None
|
||||
_issue_instant: str
|
||||
_response_id: str
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
provider: SAMLProvider,
|
||||
logout_request: LogoutRequest,
|
||||
destination: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.logout_request = logout_request
|
||||
self.destination = destination
|
||||
self.relay_state = relay_state or (logout_request.relay_state if logout_request else None)
|
||||
self.issuer = issuer
|
||||
self._issue_instant = get_time_string()
|
||||
self._response_id = get_random_id()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value from session, with fallback to provider"""
|
||||
if self.issuer:
|
||||
return self.issuer
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
return DEFAULT_ISSUER
|
||||
|
||||
def get_issuer(self) -> Element:
|
||||
"""Get Issuer element"""
|
||||
issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
issuer.text = self.provider.issuer
|
||||
issuer.text = self._get_issuer_value()
|
||||
return issuer
|
||||
|
||||
def build(self, status: str = "Success") -> Element:
|
||||
|
||||
@@ -40,6 +40,19 @@ class MetadataProcessor:
|
||||
self.force_binding = None
|
||||
self.xml_id = "_" + sha256(f"{provider.name}-{provider.pk}".encode("ascii")).hexdigest()
|
||||
|
||||
def _get_issuer_value(self) -> str:
|
||||
"""Get issuer value, with fallback to generated URL if empty"""
|
||||
# If user has set an override issuer, use it
|
||||
if self.provider.issuer_override:
|
||||
return self.provider.issuer_override
|
||||
|
||||
return self.http_request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_saml:base",
|
||||
kwargs={"application_slug": self.provider.application.slug},
|
||||
)
|
||||
)
|
||||
|
||||
# Using type unions doesn't work with cython types (which is what lxml is)
|
||||
def get_signing_key_descriptor(self) -> Element | None:
|
||||
"""Get Signing KeyDescriptor, if enabled for the provider"""
|
||||
@@ -189,7 +202,7 @@ class MetadataProcessor:
|
||||
"""Build full EntityDescriptor"""
|
||||
entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP)
|
||||
entity_descriptor.attrib["ID"] = self.xml_id
|
||||
entity_descriptor.attrib["entityID"] = self.provider.issuer
|
||||
entity_descriptor.attrib["entityID"] = self._get_issuer_value()
|
||||
|
||||
if self.provider.signing_kp:
|
||||
self._prepare_signature(entity_descriptor)
|
||||
|
||||
@@ -51,7 +51,6 @@ class ServiceProviderMetadata:
|
||||
provider = SAMLProvider.objects.create(
|
||||
name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
|
||||
)
|
||||
provider.issuer = self.entity_id
|
||||
provider.sp_binding = self.acs_binding
|
||||
provider.acs_url = self.acs_location
|
||||
provider.default_name_id_policy = self.name_id_policy
|
||||
|
||||
@@ -75,6 +75,7 @@ def handle_saml_iframe_pre_user_logout(
|
||||
name_id_format=session.name_id_format,
|
||||
session_index=session.session_index,
|
||||
relay_state=relay_state,
|
||||
issuer=session.issuer,
|
||||
)
|
||||
|
||||
if session.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -163,6 +164,7 @@ def handle_flow_pre_user_logout(
|
||||
name_id_format=session.name_id_format,
|
||||
session_index=session.session_index,
|
||||
relay_state=relay_state,
|
||||
issuer=session.issuer,
|
||||
)
|
||||
|
||||
if session.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -224,6 +226,7 @@ def user_session_deleted_saml_logout(sender, instance: AuthenticatedSession, **_
|
||||
name_id=saml_session.name_id,
|
||||
name_id_format=saml_session.name_id_format,
|
||||
session_index=saml_session.session_index,
|
||||
issuer=saml_session.issuer,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,4 +260,5 @@ def user_deactivated_saml_logout(sender, instance: User, **kwargs):
|
||||
name_id=saml_session.name_id,
|
||||
name_id_format=saml_session.name_id_format,
|
||||
session_index=saml_session.session_index,
|
||||
issuer=saml_session.issuer,
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ def send_saml_logout_request(
|
||||
name_id: str,
|
||||
name_id_format: str,
|
||||
session_index: str,
|
||||
issuer: str,
|
||||
):
|
||||
"""Send SAML LogoutRequest to a Service Provider using session data"""
|
||||
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
|
||||
@@ -47,6 +48,7 @@ def send_saml_logout_request(
|
||||
name_id=name_id,
|
||||
name_id_format=name_id_format,
|
||||
session_index=session_index,
|
||||
issuer=issuer,
|
||||
)
|
||||
|
||||
return send_post_logout_request(provider, processor)
|
||||
@@ -89,6 +91,7 @@ def send_saml_logout_response(
|
||||
sls_url: str,
|
||||
logout_request_id: str | None = None,
|
||||
relay_state: str | None = None,
|
||||
issuer: str | None = None,
|
||||
):
|
||||
"""Send SAML LogoutResponse to a Service Provider using backchannel (server-to-server)"""
|
||||
provider = SAMLProvider.objects.filter(pk=provider_pk).first()
|
||||
@@ -119,6 +122,7 @@ def send_saml_logout_response(
|
||||
logout_request=logout_request,
|
||||
destination=sls_url,
|
||||
relay_state=relay_state,
|
||||
issuer=issuer,
|
||||
)
|
||||
|
||||
encoded_response = processor.encode_post()
|
||||
|
||||
@@ -15,6 +15,7 @@ from authentik.common.saml.constants import (
|
||||
SAML_NAME_ID_FORMAT_EMAIL,
|
||||
SAML_NAME_ID_FORMAT_UNSPECIFIED,
|
||||
)
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import (
|
||||
RequestFactory,
|
||||
create_test_admin_user,
|
||||
@@ -97,6 +98,11 @@ class TestAuthNRequest(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app",
|
||||
provider=self.provider,
|
||||
)
|
||||
self.source = SAMLSource.objects.create(
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
@@ -526,7 +532,7 @@ class TestAuthNRequest(TestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
audience="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
signing_kp=static_keypair,
|
||||
verification_kp=static_keypair,
|
||||
)
|
||||
@@ -547,7 +553,7 @@ class TestAuthNRequest(TestCase):
|
||||
"saml/acs/2d737f96-55fb-4035-953e-5e24134eb778"
|
||||
),
|
||||
audience="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
issuer_override="https://10.120.20.200/saml-sp/SAML2/POST",
|
||||
signing_kp=create_test_cert(),
|
||||
)
|
||||
parsed_request = AuthNRequestParser(provider).parse(POST_REQUEST)
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestNativeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp1.example.com/acs",
|
||||
sls_url="https://sp1.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
|
||||
@@ -58,7 +58,7 @@ class TestNativeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="post",
|
||||
sls_binding="post",
|
||||
logout_method=SAMLLogoutMethods.FRONTCHANNEL_NATIVE,
|
||||
@@ -218,7 +218,7 @@ class TestIframeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp1.example.com/acs",
|
||||
sls_url="https://sp1.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
logout_method="frontchannel_iframe",
|
||||
@@ -229,7 +229,7 @@ class TestIframeLogoutStageView(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="post",
|
||||
sls_binding="post",
|
||||
logout_method="frontchannel_iframe",
|
||||
@@ -372,7 +372,7 @@ class TestIdPLogoutIntegration(FlowTestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.keypair,
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestLogoutIntegration(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signature_algorithm=RSA_SHA256,
|
||||
@@ -57,7 +57,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -72,7 +72,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse_detached(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -106,7 +106,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = parser.parse(encoded)
|
||||
|
||||
# Verify all fields match
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.name_id, "signed@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "signed-session-456")
|
||||
@@ -125,7 +125,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse_detached(saml_request)
|
||||
|
||||
# Verify parsing succeeded
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
|
||||
@@ -164,7 +164,7 @@ class TestLogoutIntegration(TestCase):
|
||||
|
||||
# Parse the SAMLRequest (unsigned XML)
|
||||
parsed = self.parser.parse_detached(params["SAMLRequest"][0])
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
|
||||
def test_form_data_can_be_parsed(self):
|
||||
"""Test that form data generates parseable POST request"""
|
||||
@@ -175,7 +175,7 @@ class TestLogoutIntegration(TestCase):
|
||||
parsed = self.parser.parse(form_data["SAMLRequest"])
|
||||
|
||||
# Verify parsing succeeded
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(parsed.name_id, "test@example.com")
|
||||
self.assertEqual(parsed.name_id_format, SAML_NAME_ID_FORMAT_EMAIL)
|
||||
self.assertEqual(parsed.session_index, "test-session-123")
|
||||
@@ -244,4 +244,4 @@ class TestLogoutIntegration(TestCase):
|
||||
|
||||
# But same issuer
|
||||
self.assertEqual(parsed1.issuer, parsed2.issuer)
|
||||
self.assertEqual(parsed1.issuer, self.provider.issuer)
|
||||
self.assertEqual(parsed1.issuer, self.provider.issuer_override)
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestLogoutRequestProcessor(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signature_algorithm=RSA_SHA256,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""logout response tests"""
|
||||
|
||||
from defusedxml import ElementTree
|
||||
from django.test import TestCase
|
||||
from django.test import RequestFactory, TestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.common.saml.constants import (
|
||||
@@ -9,10 +9,13 @@ from authentik.common.saml.constants import (
|
||||
NS_SAML_PROTOCOL,
|
||||
NS_SIGNATURE,
|
||||
)
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||
from authentik.providers.saml.processors.logout_request_parser import LogoutRequest
|
||||
from authentik.providers.saml.processors.logout_response_processor import LogoutResponseProcessor
|
||||
from authentik.providers.saml.processors.metadata import MetadataProcessor
|
||||
|
||||
|
||||
class TestLogoutResponse(TestCase):
|
||||
@@ -21,6 +24,7 @@ class TestLogoutResponse(TestCase):
|
||||
@apply_blueprint("system/providers-saml.yaml")
|
||||
def setUp(self):
|
||||
cert = create_test_cert()
|
||||
self.factory = RequestFactory()
|
||||
self.provider: SAMLProvider = SAMLProvider.objects.create(
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="http://testserver/source/saml/provider/acs/",
|
||||
@@ -30,17 +34,31 @@ class TestLogoutResponse(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
self.application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_build_response(self):
|
||||
"""Test building a LogoutResponse"""
|
||||
"""Test building a LogoutResponse uses the generated issuer from the assertion"""
|
||||
# Generate the issuer the same way the assertion/metadata processors would
|
||||
request = self.factory.get("/")
|
||||
metadata_processor = MetadataProcessor(self.provider, request)
|
||||
generated_issuer = metadata_processor._get_issuer_value()
|
||||
|
||||
logout_request = LogoutRequest(
|
||||
id="test-request-id",
|
||||
issuer="test-sp",
|
||||
relay_state="test-relay-state",
|
||||
)
|
||||
|
||||
# Pass the generated issuer as if it came from SAMLSession.issuer
|
||||
processor = LogoutResponseProcessor(
|
||||
self.provider, logout_request, destination=self.provider.sls_url
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=generated_issuer,
|
||||
)
|
||||
response_xml = processor.build_response(status="Success")
|
||||
|
||||
@@ -51,9 +69,9 @@ class TestLogoutResponse(TestCase):
|
||||
self.assertEqual(root.attrib["Destination"], self.provider.sls_url)
|
||||
self.assertEqual(root.attrib["InResponseTo"], "test-request-id")
|
||||
|
||||
# Check Issuer
|
||||
# Check Issuer matches the generated issuer from the assertion processor
|
||||
issuer = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
|
||||
self.assertEqual(issuer.text, self.provider.issuer)
|
||||
self.assertEqual(issuer.text, generated_issuer)
|
||||
|
||||
# Check Status
|
||||
status = root.find(f".//{{{NS_SAML_PROTOCOL}}}StatusCode")
|
||||
|
||||
@@ -85,7 +85,6 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml"))
|
||||
provider = metadata.to_provider("test", self.flow, self.flow)
|
||||
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
|
||||
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
|
||||
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
||||
self.assertEqual(provider.default_name_id_policy, SAMLNameIDPolicy.EMAIL)
|
||||
self.assertEqual(
|
||||
@@ -99,7 +98,6 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/cert.xml"))
|
||||
provider = metadata.to_provider("test", self.flow, self.flow)
|
||||
self.assertEqual(provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs")
|
||||
self.assertEqual(provider.issuer, "http://localhost:8080/apps/user_saml/saml/metadata")
|
||||
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
||||
self.assertEqual(
|
||||
provider.verification_kp.certificate_data, load_fixture("fixtures/cert.pem")
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name="test-provider",
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
)
|
||||
|
||||
# Create another provider for testing
|
||||
@@ -40,7 +40,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name="test-provider-2",
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
issuer="https://idp2.example.com",
|
||||
issuer_override="https://idp2.example.com",
|
||||
)
|
||||
|
||||
# Create a session first (using authentik's custom Session model)
|
||||
@@ -72,6 +72,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify the session was created
|
||||
@@ -100,6 +101,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Try to create another session with same session_index and provider
|
||||
@@ -113,6 +115,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
def test_cascade_deletion_user(self):
|
||||
@@ -127,6 +130,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -150,6 +154,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -173,6 +178,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify session exists
|
||||
@@ -196,6 +202,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Create second session with different provider
|
||||
@@ -208,6 +215,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify both sessions exist
|
||||
@@ -229,6 +237,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=future_time,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify expiry time
|
||||
@@ -248,6 +257,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=past_time,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Check if marked as expired
|
||||
@@ -265,6 +275,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format="", # Blank format
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify it was created successfully
|
||||
@@ -283,6 +294,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
session2 = SAMLSession.objects.create(
|
||||
@@ -294,6 +306,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Query by provider
|
||||
@@ -316,6 +329,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Check serializer property
|
||||
@@ -334,6 +348,7 @@ class TestSAMLSessionModel(TestCase):
|
||||
name_id_format=self.name_id_format,
|
||||
expires=self.expires,
|
||||
expiring=True,
|
||||
issuer="authentik",
|
||||
)
|
||||
|
||||
# Verify sessions exist
|
||||
|
||||
@@ -7,6 +7,7 @@ from guardian.shortcuts import get_anonymous_user
|
||||
from lxml import etree # nosec
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import RequestFactory, create_test_cert, create_test_flow
|
||||
from authentik.lib.xml import lxml_from_string
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||
@@ -30,6 +31,11 @@ class TestSchema(TestCase):
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app",
|
||||
provider=self.provider,
|
||||
)
|
||||
self.source = SAMLSource.objects.create(
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestSendSamlLogoutResponse(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
@@ -155,6 +155,7 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
@@ -179,6 +180,7 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
@@ -198,6 +200,7 @@ class TestSendSamlLogoutRequest(TestCase):
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
|
||||
@@ -214,7 +217,7 @@ class TestSendPostLogoutRequest(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
)
|
||||
@@ -90,7 +90,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
# Verify logout request was stored in plan context
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_redirect_view_handles_logout_response_with_plan_context(self):
|
||||
@@ -228,7 +228,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
# Verify logout request was stored in plan context
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
logout_request = view.plan_context["authentik/providers/saml/logout_request"]
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer_override)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_post_view_handles_logout_response_with_plan_context(self):
|
||||
@@ -396,7 +396,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
authorization_flow=self.flow,
|
||||
acs_url="https://sp2.example.com/acs",
|
||||
sls_url="https://sp2.example.com/sls",
|
||||
issuer="https://idp2.example.com",
|
||||
issuer_override="https://idp2.example.com",
|
||||
invalidation_flow=None, # No invalidation flow
|
||||
)
|
||||
|
||||
@@ -524,7 +524,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.cert,
|
||||
@@ -714,7 +714,7 @@ class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="", # No SLS URL
|
||||
issuer="https://idp.example.com",
|
||||
issuer_override="https://idp.example.com",
|
||||
)
|
||||
|
||||
app_no_sls = Application.objects.create(
|
||||
|
||||
@@ -11,6 +11,12 @@ from authentik.providers.saml.views.sp_slo import (
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
# Base path for Issuer/Entity ID
|
||||
path(
|
||||
"<slug:application_slug>/",
|
||||
sso.SAMLSSOBindingRedirectView.as_view(),
|
||||
name="base",
|
||||
),
|
||||
# SSO Bindings
|
||||
path(
|
||||
"<slug:application_slug>/sso/binding/redirect/",
|
||||
|
||||
@@ -81,6 +81,7 @@ class SAMLFlowFinalView(ChallengeStageView):
|
||||
"session": auth_session,
|
||||
"name_id": processor.name_id,
|
||||
"name_id_format": processor.name_id_format,
|
||||
"issuer": processor.issuer,
|
||||
"expires": processor.session_not_on_or_after_datetime,
|
||||
"expiring": True,
|
||||
},
|
||||
|
||||
@@ -107,12 +107,25 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
# Store relay state for the logout response
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = relay_state
|
||||
|
||||
# Look up the session issuer to use in the logout response
|
||||
auth_session = AuthenticatedSession.from_request(request, request.user)
|
||||
session_issuer = None
|
||||
if auth_session:
|
||||
saml_session = SAMLSession.objects.filter(
|
||||
session=auth_session,
|
||||
user=request.user,
|
||||
provider=self.provider,
|
||||
).first()
|
||||
if saml_session:
|
||||
session_issuer = saml_session.issuer
|
||||
|
||||
if self.provider.logout_method == SAMLLogoutMethods.FRONTCHANNEL_NATIVE:
|
||||
# Native mode - user will be redirected/posted away from authentik
|
||||
processor = LogoutResponseProcessor(
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
if self.provider.sls_binding == SAMLBindings.POST:
|
||||
@@ -152,6 +165,7 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
sls_url=self.provider.sls_url,
|
||||
logout_request_id=logout_request.id if logout_request else None,
|
||||
relay_state=relay_state,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
LOGGER.debug(
|
||||
@@ -168,6 +182,7 @@ class SPInitiatedSLOView(PolicyAccessView):
|
||||
self.provider,
|
||||
logout_request,
|
||||
destination=self.provider.sls_url,
|
||||
issuer=session_issuer,
|
||||
)
|
||||
|
||||
logout_response = processor.build_response()
|
||||
|
||||
@@ -6,6 +6,7 @@ from django.db import migrations, models
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0056_user_roles"), # must run before group field is removed
|
||||
("authentik_rbac", "0009_remove_initialpermissions_mode"),
|
||||
]
|
||||
|
||||
|
||||
@@ -172,6 +172,7 @@ SPECTACULAR_SETTINGS = {
|
||||
},
|
||||
"ENUM_NAME_OVERRIDES": {
|
||||
"AppEnum": "authentik.lib.api.Apps",
|
||||
"AuthenticationEnum": "authentik.flows.models.FlowAuthenticationRequirement",
|
||||
"ConsentModeEnum": "authentik.stages.consent.models.ConsentMode",
|
||||
"CountryCodeEnum": "django_countries.countries",
|
||||
"DeviceClassesEnum": "authentik.stages.authenticator_validate.models.DeviceClasses",
|
||||
@@ -186,6 +187,7 @@ SPECTACULAR_SETTINGS = {
|
||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||
"RedirectURITypeEnum": "authentik.providers.oauth2.models.RedirectURIType",
|
||||
"SAMLBindingsEnum": "authentik.providers.saml.models.SAMLBindings",
|
||||
"SAMLLogoutMethods": "authentik.providers.saml.models.SAMLLogoutMethods",
|
||||
"SAMLNameIDPolicyEnum": "authentik.sources.saml.models.SAMLNameIDPolicy",
|
||||
|
||||
File diff suppressed because one or more lines are too long
64
authentik/stages/prompt/migrations/0012_alter_prompt_type.py
Normal file
64
authentik/stages/prompt/migrations/0012_alter_prompt_type.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Generated by Django 5.2.12 on 2026-03-14 02:58
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_stages_prompt",
|
||||
"0011_prompt_initial_value_prompt_initial_value_expression_and_more",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="prompt",
|
||||
name="type",
|
||||
field=models.CharField(
|
||||
choices=[
|
||||
("text", "Text: Simple Text input"),
|
||||
("text_area", "Text area: Multiline Text Input."),
|
||||
(
|
||||
"text_read_only",
|
||||
"Text (read-only): Simple Text input, but cannot be edited.",
|
||||
),
|
||||
(
|
||||
"text_area_read_only",
|
||||
"Text area (read-only): Multiline Text input, but cannot be edited.",
|
||||
),
|
||||
(
|
||||
"username",
|
||||
"Username: Same as Text input, but checks for and prevents duplicate usernames.",
|
||||
),
|
||||
("email", "Email: Text field with Email type."),
|
||||
(
|
||||
"password",
|
||||
"Password: Masked input, multiple inputs of this type on the same prompt need to be identical.",
|
||||
),
|
||||
("number", "Number"),
|
||||
("checkbox", "Checkbox"),
|
||||
(
|
||||
"radio-button-group",
|
||||
"Fixed choice field rendered as a group of radio buttons.",
|
||||
),
|
||||
("dropdown", "Fixed choice field rendered as a dropdown."),
|
||||
("date", "Date"),
|
||||
("date-time", "Date Time"),
|
||||
(
|
||||
"file",
|
||||
"File: File upload for arbitrary files. File content will be available in flow context as data-URI",
|
||||
),
|
||||
("separator", "Separator: Static Separator Line"),
|
||||
("hidden", "Hidden: Hidden field, can be used to insert data into form."),
|
||||
("static", "Static: Static value, displayed as-is."),
|
||||
("alert_info", "Alert (Info): Static alert box with info styling"),
|
||||
("alert_warning", "Alert (Warning): Static alert box with warning styling"),
|
||||
("alert_danger", "Alert (Danger): Static alert box with danger styling"),
|
||||
("ak-locale", "authentik: Selection of locales authentik supports"),
|
||||
],
|
||||
max_length=100,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -87,6 +87,11 @@ class FieldTypes(models.TextChoices):
|
||||
HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.")
|
||||
STATIC = "static", _("Static: Static value, displayed as-is.")
|
||||
|
||||
# Alert box types for displaying styled messages
|
||||
ALERT_INFO = "alert_info", _("Alert (Info): Static alert box with info styling")
|
||||
ALERT_WARNING = "alert_warning", _("Alert (Warning): Static alert box with warning styling")
|
||||
ALERT_DANGER = "alert_danger", _("Alert (Danger): Static alert box with danger styling")
|
||||
|
||||
AK_LOCALE = "ak-locale", _("authentik: Selection of locales authentik supports")
|
||||
|
||||
|
||||
@@ -299,7 +304,12 @@ class Prompt(SerializerModel):
|
||||
field_class = HiddenField
|
||||
kwargs["required"] = False
|
||||
kwargs["default"] = self.placeholder
|
||||
case FieldTypes.STATIC:
|
||||
case (
|
||||
FieldTypes.STATIC
|
||||
| FieldTypes.ALERT_INFO
|
||||
| FieldTypes.ALERT_WARNING
|
||||
| FieldTypes.ALERT_DANGER
|
||||
):
|
||||
kwargs["default"] = self.placeholder
|
||||
kwargs["required"] = False
|
||||
kwargs["label"] = ""
|
||||
|
||||
@@ -124,6 +124,9 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||
type__in=[
|
||||
FieldTypes.HIDDEN,
|
||||
FieldTypes.STATIC,
|
||||
FieldTypes.ALERT_INFO,
|
||||
FieldTypes.ALERT_WARNING,
|
||||
FieldTypes.ALERT_DANGER,
|
||||
FieldTypes.TEXT_READ_ONLY,
|
||||
FieldTypes.TEXT_AREA_READ_ONLY,
|
||||
]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user