mirror of
https://github.com/goauthentik/authentik
synced 2026-04-25 17:15:26 +02:00
Compare commits
180 Commits
hide_apps
...
version-20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99a93fa8a2 | ||
|
|
bd2a0e1d7d | ||
|
|
c4d455dd3a | ||
|
|
508dba6a04 | ||
|
|
aa921dcdca | ||
|
|
e5d873c129 | ||
|
|
f0a14d380f | ||
|
|
1da15a549e | ||
|
|
eaf1c45ea6 | ||
|
|
f0f42668c4 | ||
|
|
123fbd26bb | ||
|
|
b94d93b6c4 | ||
|
|
d0b25bf648 | ||
|
|
d4db4e50b4 | ||
|
|
c5e726d7eb | ||
|
|
203a7e0c61 | ||
|
|
2feaeff5db | ||
|
|
8fcc47e047 | ||
|
|
7a6408cc67 | ||
|
|
2da88028da | ||
|
|
fa91404895 | ||
|
|
460fce7279 | ||
|
|
995128955c | ||
|
|
85536abbcf | ||
|
|
5249546862 | ||
|
|
bf91348c05 | ||
|
|
63136f0180 | ||
|
|
faffabf938 | ||
|
|
0b180b15a2 | ||
|
|
07af6de74f | ||
|
|
ddfef91ea5 | ||
|
|
cefbf5e6ae | ||
|
|
e53d3d2486 | ||
|
|
32a3eed521 | ||
|
|
f05cc6e75a | ||
|
|
c68c36fdeb | ||
|
|
888f969fc7 | ||
|
|
82535e4671 | ||
|
|
ed2957e4e6 | ||
|
|
a5abe85148 | ||
|
|
8d2c31fa25 | ||
|
|
2637ce2474 | ||
|
|
319008dec8 | ||
|
|
8beb2fac18 | ||
|
|
ac7b28d0b0 | ||
|
|
073acf92c2 | ||
|
|
ad107c19af | ||
|
|
d285fcd8a7 | ||
|
|
84066cab48 | ||
|
|
e623d93ff5 | ||
|
|
1d0628dfbe | ||
|
|
996645105c | ||
|
|
63d7ca6ef0 | ||
|
|
5b24f4ad80 | ||
|
|
ed2e6cfb9c | ||
|
|
a1431ea48e | ||
|
|
b30e77b363 | ||
|
|
2f50cdd9fe | ||
|
|
494bdcaa09 | ||
|
|
e36ce1789e | ||
|
|
5a72ed83e0 | ||
|
|
f72d257e43 | ||
|
|
cbedb16cc4 | ||
|
|
6fc1b5ce90 | ||
|
|
57b0fa48c1 | ||
|
|
84a344ed87 | ||
|
|
f864cb56ab | ||
|
|
692735f9e1 | ||
|
|
e24fb300b1 | ||
|
|
f0e90d6873 | ||
|
|
0cf45835a0 | ||
|
|
69d35c1d26 | ||
|
|
ac803b210d | ||
|
|
c9728b4607 | ||
|
|
6e45584563 | ||
|
|
59a2e84b35 | ||
|
|
6025dbb9c9 | ||
|
|
d07bcd5025 | ||
|
|
e80655d285 | ||
|
|
e0d3d4d38c | ||
|
|
62112404ee | ||
|
|
1c9e12fcd9 | ||
|
|
42c6c257ec | ||
|
|
41bd9d7913 | ||
|
|
2c84935732 | ||
|
|
819c13a9bc | ||
|
|
0d8f366af8 | ||
|
|
093e60c753 | ||
|
|
af646f32d2 | ||
|
|
de4afc7322 | ||
|
|
bc1983106f | ||
|
|
8c2c1474f1 | ||
|
|
0dccbd4193 | ||
|
|
6a70894e01 | ||
|
|
2f5eb9b2e4 | ||
|
|
12aedb3a9e | ||
|
|
303dc93514 | ||
|
|
fbb217db57 | ||
|
|
4de253653f | ||
|
|
4154c06831 | ||
|
|
4750ed5e2a | ||
|
|
361017127d | ||
|
|
0ca5a54307 | ||
|
|
ef1aad5dbb | ||
|
|
29d880920e | ||
|
|
fc6f8374e6 | ||
|
|
a8668bbac4 | ||
|
|
d686932166 | ||
|
|
feceb220b1 | ||
|
|
937df6e07f | ||
|
|
48e6b968a6 | ||
|
|
cd89c45e75 | ||
|
|
e53995e2c1 | ||
|
|
33d5f11f0e | ||
|
|
565e16eca7 | ||
|
|
9a0164b722 | ||
|
|
8af491630b | ||
|
|
8e25e7a213 | ||
|
|
4d183657da | ||
|
|
be89b6052d | ||
|
|
ad5d2bb611 | ||
|
|
8d30fb3d25 | ||
|
|
cea3fbfa9b | ||
|
|
151d889ff4 | ||
|
|
58ca3ecbd5 | ||
|
|
1a6c7082a3 | ||
|
|
1dc60276f9 | ||
|
|
de045c6d7b | ||
|
|
850728e9bb | ||
|
|
84a605a4ba | ||
|
|
1780bb0cf0 | ||
|
|
cd75fe235d | ||
|
|
e6e62e9de1 | ||
|
|
ac7a4f8a22 | ||
|
|
0290ed3342 | ||
|
|
e367525794 | ||
|
|
93c319baee | ||
|
|
1d02ee7d74 | ||
|
|
93439b5742 | ||
|
|
6682a6664e | ||
|
|
0b5bac74e9 | ||
|
|
062823f1b2 | ||
|
|
a17fe58971 | ||
|
|
422ea893b1 | ||
|
|
15c9f93851 | ||
|
|
e2202d498b | ||
|
|
9ea9a86ad3 | ||
|
|
4bac1edd61 | ||
|
|
24726be3c9 | ||
|
|
411f06756f | ||
|
|
4bdcab48c3 | ||
|
|
00dbd377a7 | ||
|
|
a01c0575db | ||
|
|
6e51d044bb | ||
|
|
6d1b168dc4 | ||
|
|
43675c2b22 | ||
|
|
8645273eaf | ||
|
|
eb6f4712fe | ||
|
|
7b9505242e | ||
|
|
3dda20ebc7 | ||
|
|
dfd2bc5c3c | ||
|
|
06a270913c | ||
|
|
430507fc72 | ||
|
|
847af7f9ea | ||
|
|
8f1cb636e8 | ||
|
|
e61c876002 | ||
|
|
33c0d3df0a | ||
|
|
3a03e1ebfd | ||
|
|
1e41b77761 | ||
|
|
6c1662f99f | ||
|
|
bb5bc5c8da | ||
|
|
30670c9070 | ||
|
|
fdbf9ffedc | ||
|
|
2ec433d724 | ||
|
|
55297b9e6a | ||
|
|
f9dda6582c | ||
|
|
3394c17bfd | ||
|
|
a37d101b10 | ||
|
|
4774b4db87 | ||
|
|
fdb52c9394 |
23
.github/actions/cherry-pick/action.yml
vendored
23
.github/actions/cherry-pick/action.yml
vendored
@@ -115,20 +115,13 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ inputs.token }}
|
||||
PR_NUMBER: ${{ steps.should_run.outputs.pr_number }}
|
||||
REASON: ${{ steps.should_run.outputs.reason }}
|
||||
run: |
|
||||
set -e -o pipefail
|
||||
PR_NUMBER="${{ steps.should_run.outputs.pr_number }}"
|
||||
|
||||
# Get PR details
|
||||
PR_DATA=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER)
|
||||
PR_TITLE=$(echo "$PR_DATA" | jq -r '.title')
|
||||
PR_AUTHOR=$(echo "$PR_DATA" | jq -r '.user.login')
|
||||
|
||||
echo "pr_title=$PR_TITLE" >> $GITHUB_OUTPUT
|
||||
echo "pr_author=$PR_AUTHOR" >> $GITHUB_OUTPUT
|
||||
|
||||
# Determine which labels to process
|
||||
if [ "${{ steps.should_run.outputs.reason }}" = "label_added_to_merged_pr" ]; then
|
||||
if [ "${REASON}" = "label_added_to_merged_pr" ]; then
|
||||
# Only process the specific label that was just added
|
||||
if [ "${{ github.event_name }}" = "issues" ]; then
|
||||
LABEL_NAME="${{ github.event.label.name }}"
|
||||
@@ -152,13 +145,13 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ inputs.token }}
|
||||
PR_NUMBER: '${{ steps.should_run.outputs.pr_number }}'
|
||||
COMMIT_SHA: '${{ steps.should_run.outputs.merge_commit_sha }}'
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
LABELS: '${{ steps.pr_details.outputs.labels }}'
|
||||
run: |
|
||||
set -e -o pipefail
|
||||
PR_NUMBER='${{ steps.should_run.outputs.pr_number }}'
|
||||
COMMIT_SHA='${{ steps.should_run.outputs.merge_commit_sha }}'
|
||||
PR_TITLE='${{ steps.pr_details.outputs.pr_title }}'
|
||||
PR_AUTHOR='${{ steps.pr_details.outputs.pr_author }}'
|
||||
LABELS='${{ steps.pr_details.outputs.labels }}'
|
||||
|
||||
echo "Processing PR #$PR_NUMBER (reason: ${{ steps.should_run.outputs.reason }})"
|
||||
echo "Found backport labels: $LABELS"
|
||||
|
||||
@@ -89,6 +89,8 @@ if should_push:
|
||||
_cache_tag = "buildcache"
|
||||
if image_arch:
|
||||
_cache_tag += f"-{image_arch}"
|
||||
if is_release:
|
||||
_cache_tag += f"-{version_family}"
|
||||
cache_to = f"type=registry,ref={get_attest_image_names(image_tags)}:{_cache_tag},mode=max"
|
||||
|
||||
|
||||
|
||||
38
.github/actions/setup/action.yml
vendored
38
.github/actions/setup/action.yml
vendored
@@ -8,18 +8,33 @@ inputs:
|
||||
postgresql_version:
|
||||
description: "Optional postgresql image tag"
|
||||
default: "16"
|
||||
working-directory:
|
||||
description: |
|
||||
Optional working directory if this repo isn't in the root of the actions workspace.
|
||||
When set, needs to contain a trailing slash
|
||||
default: ""
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install apt deps & cleanup
|
||||
- name: Cleanup apt
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: sudo apt-get remove --purge man-db
|
||||
- name: Install apt deps
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
uses: gerlero/apt-install@f4fa5265092af9e750549565d28c99aec7189639
|
||||
with:
|
||||
packages: libpq-dev openssl libxmlsec1-dev pkg-config gettext krb5-multidev libkrb5-dev heimdal-multidev libclang-dev krb5-kdc krb5-user krb5-admin-server
|
||||
update: true
|
||||
upgrade: false
|
||||
install-recommends: false
|
||||
- name: Make space on disk
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get remove --purge man-db
|
||||
sudo apt-get update
|
||||
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext krb5-multidev libkrb5-dev heimdal-multidev libclang-dev krb5-kdc krb5-user krb5-admin-server
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo mkdir -p /tmp/empty/
|
||||
sudo rsync -a --delete /tmp/empty/ /usr/local/lib/android/
|
||||
- name: Install uv
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v5
|
||||
@@ -29,24 +44,25 @@ runs:
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5
|
||||
with:
|
||||
python-version-file: "pyproject.toml"
|
||||
python-version-file: "${{ inputs.working-directory }}pyproject.toml"
|
||||
- name: Install Python deps
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: uv sync --all-extras --dev --frozen
|
||||
- name: Setup node
|
||||
if: ${{ contains(inputs.dependencies, 'node') }}
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4
|
||||
with:
|
||||
node-version-file: web/package.json
|
||||
node-version-file: ${{ inputs.working-directory }}web/package.json
|
||||
cache: "npm"
|
||||
cache-dependency-path: web/package-lock.json
|
||||
cache-dependency-path: ${{ inputs.working-directory }}web/package-lock.json
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
- name: Setup go
|
||||
if: ${{ contains(inputs.dependencies, 'go') }}
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version-file: "${{ inputs.working-directory }}go.mod"
|
||||
- name: Setup docker cache
|
||||
if: ${{ contains(inputs.dependencies, 'runtime') }}
|
||||
uses: AndreKurait/docker-cache@0fe76702a40db986d9663c24954fc14c6a6031b7
|
||||
@@ -55,13 +71,15 @@ runs:
|
||||
- name: Setup dependencies
|
||||
if: ${{ contains(inputs.dependencies, 'runtime') }}
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
export PSQL_TAG=${{ inputs.postgresql_version }}
|
||||
docker compose -f .github/actions/setup/compose.yml up -d
|
||||
cd web && npm i
|
||||
cd web && npm ci
|
||||
- name: Generate config
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
shell: uv run python {0}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
from authentik.lib.generators import generate_id
|
||||
from yaml import safe_dump
|
||||
|
||||
2
.github/actions/setup/compose.yml
vendored
2
.github/actions/setup/compose.yml
vendored
@@ -2,7 +2,7 @@ services:
|
||||
postgresql:
|
||||
image: docker.io/library/postgres:${PSQL_TAG:-16}
|
||||
volumes:
|
||||
- db-data:/var/lib/postgresql/data
|
||||
- db-data:/var/lib/postgresql
|
||||
command: "-c log_statement=all"
|
||||
environment:
|
||||
POSTGRES_USER: authentik
|
||||
|
||||
2
.github/workflows/api-ts-publish.yml
vendored
2
.github/workflows/api-ts-publish.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
5
.github/workflows/ci-main.yml
vendored
5
.github/workflows/ci-main.yml
vendored
@@ -95,7 +95,10 @@ jobs:
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
- name: run migrations to stable
|
||||
run: uv run python -m lifecycle.migrate
|
||||
run: |
|
||||
docker ps
|
||||
docker logs setup-postgresql-1
|
||||
uv run python -m lifecycle.migrate
|
||||
- name: checkout current code
|
||||
run: |
|
||||
set -x
|
||||
|
||||
2
.github/workflows/gen-image-compress.yml
vendored
2
.github/workflows/gen-image-compress.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
2
.github/workflows/gh-cherry-pick.yml
vendored
2
.github/workflows/gh-cherry-pick.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
if: ${{ env.GH_APP_ID != '' }}
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
env:
|
||||
GH_APP_ID: ${{ secrets.GH_APP_ID }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
|
||||
2
.github/workflows/gh-ghcr-retention.yml
vendored
2
.github/workflows/gh-ghcr-retention.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Delete 'dev' containers older than a week
|
||||
uses: snok/container-retention-policy@3b0972b2276b171b212f8c4efbca59ebba26eceb # v3.0.1
|
||||
with:
|
||||
|
||||
4
.github/workflows/release-branch-off.yml
vendored
4
.github/workflows/release-branch-off.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Checkout main
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Checkout main
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
|
||||
21
.github/workflows/release-publish.yml
vendored
21
.github/workflows/release-publish.yml
vendored
@@ -160,10 +160,17 @@ jobs:
|
||||
node-version-file: web/package.json
|
||||
cache: "npm"
|
||||
cache-dependency-path: web/package-lock.json
|
||||
- name: Build web
|
||||
- name: Install web dependencies
|
||||
working-directory: web/
|
||||
run: |
|
||||
npm ci
|
||||
- name: Generate API Clients
|
||||
run: |
|
||||
make gen-client-ts
|
||||
make gen-client-go
|
||||
- name: Build web
|
||||
working-directory: web/
|
||||
run: |
|
||||
npm run build-proxy
|
||||
- name: Build outpost
|
||||
run: |
|
||||
@@ -210,12 +217,12 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
- name: Run test suite in final docker images
|
||||
run: |
|
||||
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> .env
|
||||
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> .env
|
||||
docker compose pull -q
|
||||
docker compose up --no-start
|
||||
docker compose start postgresql
|
||||
docker compose run -u root server test-all
|
||||
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> lifecycle/container/.env
|
||||
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> lifecycle/container/.env
|
||||
docker compose -f lifecycle/container/compose.yml pull -q
|
||||
docker compose -f lifecycle/container/compose.yml up --no-start
|
||||
docker compose -f lifecycle/container/compose.yml start postgresql
|
||||
docker compose -f lifecycle/container/compose.yml run -u root server test-all
|
||||
sentry-release:
|
||||
needs:
|
||||
- build-server
|
||||
|
||||
15
.github/workflows/release-tag.yml
vendored
15
.github/workflows/release-tag.yml
vendored
@@ -70,7 +70,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
run: echo "user-id=$(gh api "/users/${{ steps.app-token.outputs.app-slug }}[bot]" --jq .id)" >> "$GITHUB_OUTPUT"
|
||||
@@ -91,6 +91,7 @@ jobs:
|
||||
# ID from https://api.github.com/users/authentik-automation[bot]
|
||||
git config --global user.name '${{ steps.app-token.outputs.app-slug }}[bot]'
|
||||
git config --global user.email '${{ steps.get-user-id.outputs.user-id }}+${{ steps.app-token.outputs.app-slug }}[bot]@users.noreply.github.com'
|
||||
git pull
|
||||
git commit -a -m "release: ${{ inputs.version }}" --allow-empty
|
||||
git tag "version/${{ inputs.version }}" HEAD -m "version/${{ inputs.version }}"
|
||||
git push --follow-tags
|
||||
@@ -117,7 +118,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
repositories: helm
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
@@ -159,7 +160,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
repositories: version
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
@@ -174,21 +175,25 @@ jobs:
|
||||
if: "${{ inputs.release_reason == 'feature' }}"
|
||||
run: |
|
||||
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}"
|
||||
reason="${{ inputs.release_reason }}"
|
||||
jq \
|
||||
--arg version "${{ inputs.version }}" \
|
||||
--arg changelog "See ${changelog_url}" \
|
||||
--arg changelog_url "${changelog_url}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
|
||||
--arg reason "${reason}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
|
||||
mv version.new.json version.json
|
||||
- name: Bump version
|
||||
if: "${{ inputs.release_reason != 'feature' }}"
|
||||
run: |
|
||||
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}#fixed-in-$(echo -n ${{ inputs.version}} | sed 's/\.//g')"
|
||||
reason="${{ inputs.release_reason }}"
|
||||
jq \
|
||||
--arg version "${{ inputs.version }}" \
|
||||
--arg changelog "See ${changelog_url}" \
|
||||
--arg changelog_url "${changelog_url}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
|
||||
--arg reason "${reason}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
|
||||
mv version.new.json version.json
|
||||
- name: Create pull request
|
||||
uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v7
|
||||
|
||||
2
.github/workflows/repo-stale.yml
vendored
2
.github/workflows/repo-stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10
|
||||
with:
|
||||
repo-token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
with:
|
||||
|
||||
8
Makefile
8
Makefile
@@ -148,11 +148,11 @@ bump: ## Bump authentik version. Usage: make bump version=20xx.xx.xx
|
||||
ifndef version
|
||||
$(error Usage: make bump version=20xx.xx.xx )
|
||||
endif
|
||||
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' pyproject.toml
|
||||
$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' authentik/__init__.py
|
||||
$(eval current_version := $(shell cat ${PWD}/internal/constants/VERSION))
|
||||
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml
|
||||
$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' ${PWD}/authentik/__init__.py
|
||||
$(MAKE) gen-build gen-compose aws-cfn
|
||||
npm version --no-git-tag-version --allow-same-version $(version)
|
||||
cd ${PWD}/web && npm version --no-git-tag-version --allow-same-version $(version)
|
||||
$(SED_INPLACE) "s/\"${current_version}\"/\"$(version)\"/" ${PWD}/package.json ${PWD}/package-lock.json ${PWD}/web/package.json ${PWD}/web/package-lock.json
|
||||
echo -n $(version) > ${PWD}/internal/constants/VERSION
|
||||
|
||||
#########################
|
||||
|
||||
42
SECURITY.md
42
SECURITY.md
@@ -18,10 +18,10 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
(.x being the latest patch release for each version)
|
||||
|
||||
| Version | Supported |
|
||||
| ---------- | ---------- |
|
||||
| 2025.10.x | ✅ |
|
||||
| 2025.12.x | ✅ |
|
||||
| Version | Supported |
|
||||
| --------- | --------- |
|
||||
| 2025.12.x | ✅ |
|
||||
| 2026.2.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
@@ -60,6 +60,40 @@ authentik reserves the right to reclassify CVSS as necessary. To determine sever
|
||||
| 7.0 – 8.9 | High |
|
||||
| 9.0 – 10.0 | Critical |
|
||||
|
||||
## Intended functionality
|
||||
|
||||
The following capabilities are part of intentional system design and should not be reported as security vulnerabilities:
|
||||
|
||||
- Expressions (property mappings/policies/prompts) can execute arbitrary Python code without safeguards.
|
||||
|
||||
This is expected behavior. Any user with permission to create or modify objects containing expression fields can write code that is executed within authentik. If a vulnerability allows a user without the required permissions to write or modify code and have it executed, that would be a valid security report.
|
||||
|
||||
However, the fact that expressions are executed as part of normal operations is not considered a privilege escalation or security vulnerability.
|
||||
|
||||
- Blueprints can access all files on the filesystem.
|
||||
|
||||
This access is intentional to allow legitimate configuration and deployment tasks. It does not represent a security problem by itself.
|
||||
|
||||
- Importing blueprints allows arbitrary modification of application objects.
|
||||
|
||||
This is intended functionality. This behavior reflects the privileged design of blueprint imports. It is "exploitable" when importing blueprints from untrusted sources without reviewing the blueprint beforehand. However, any method to create, modify or execute blueprints without the required permissions would be a valid security report.
|
||||
|
||||
- Flow imports may contain objects other than flows (such as policies, users, groups, etc.)
|
||||
|
||||
This is expected behavior as flow imports are blueprint files.
|
||||
|
||||
- Prompt HTML is not escaped.
|
||||
|
||||
Prompts intentionally allow raw HTML, including script tags, so they can be used to create interactive or customized user interface elements. Because of this, scripts within prompts may affect or interact with the surrounding page as designed.
|
||||
|
||||
- Open redirects that do not include tokens or other sensitive information are not considered a security vulnerability.
|
||||
|
||||
Redirects that only change navigation flow and do not expose session tokens, API keys, or other confidential data are considered acceptable and do not require reporting.
|
||||
|
||||
- Outgoing network requests are not filtered.
|
||||
|
||||
The destinations of outgoing network requests (HTTP, TCP, etc.) made by authentik to configurable endpoints through objects such as OAuth Sources, SSO Providers, and others are not validated. Depending on your threat model, these requests should be restricted at the network level using appropriate firewall or network policies.
|
||||
|
||||
## Disclosure process
|
||||
|
||||
1. Report from Github or Issue is reported via Email as listed above.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
|
||||
VERSION = "2026.2.0-rc1"
|
||||
VERSION = "2026.2.3-rc1"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Schema generation tests"""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import gettempdir
|
||||
from uuid import uuid4
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.urls import reverse
|
||||
@@ -29,15 +31,14 @@ class TestSchemaGeneration(APITestCase):
|
||||
|
||||
def test_build_schema(self):
|
||||
"""Test schema build command"""
|
||||
blueprint_file = Path("blueprints/schema.json")
|
||||
api_file = Path("schema.yml")
|
||||
blueprint_file.unlink()
|
||||
api_file.unlink()
|
||||
tmp = Path(gettempdir())
|
||||
blueprint_file = tmp / f"{str(uuid4())}.json"
|
||||
api_file = tmp / f"{str(uuid4())}.yml"
|
||||
with (
|
||||
CONFIG.patch("debug", True),
|
||||
CONFIG.patch("tenants.enabled", True),
|
||||
CONFIG.patch("outposts.disable_embedded_outpost", True),
|
||||
):
|
||||
call_command("build_schema")
|
||||
call_command("build_schema", blueprint_file=blueprint_file, api_file=api_file)
|
||||
self.assertTrue(blueprint_file.exists())
|
||||
self.assertTrue(api_file.exists())
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from importlib import import_module
|
||||
from inspect import ismethod
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.conf import settings
|
||||
@@ -72,12 +71,19 @@ class ManagedAppConfig(AppConfig):
|
||||
|
||||
def _reconcile(self, prefix: str) -> None:
|
||||
for meth_name in dir(self):
|
||||
meth = getattr(self, meth_name)
|
||||
if not ismethod(meth):
|
||||
# Check the attribute on the class to avoid evaluating @property descriptors.
|
||||
# Using getattr(self, ...) on a @property would evaluate it, which can trigger
|
||||
# expensive side effects (e.g. tenant_schedule_specs iterating all providers
|
||||
# and running PolicyEngine queries for every user).
|
||||
class_attr = getattr(type(self), meth_name, None)
|
||||
if class_attr is None or isinstance(class_attr, property):
|
||||
continue
|
||||
category = getattr(meth, "_authentik_managed_reconcile", None)
|
||||
if not callable(class_attr):
|
||||
continue
|
||||
category = getattr(class_attr, "_authentik_managed_reconcile", None)
|
||||
if category != prefix:
|
||||
continue
|
||||
meth = getattr(self, meth_name)
|
||||
name = meth_name.replace(prefix, "")
|
||||
try:
|
||||
self.logger.debug("Starting reconciler", name=name)
|
||||
|
||||
@@ -47,7 +47,12 @@ class ApplicationSerializer(ModelSerializer):
|
||||
"""Application Serializer"""
|
||||
|
||||
launch_url = SerializerMethodField()
|
||||
provider_obj = ProviderSerializer(source="get_provider", required=False, read_only=True)
|
||||
provider_obj = ProviderSerializer(
|
||||
source="get_provider",
|
||||
required=False,
|
||||
read_only=True,
|
||||
allow_null=True,
|
||||
)
|
||||
backchannel_providers_obj = ProviderSerializer(
|
||||
source="backchannel_providers", required=False, read_only=True, many=True
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""authentik core models"""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
@@ -15,7 +17,6 @@ from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
from django.core.validators import validate_slug
|
||||
from django.db import models
|
||||
from django.db.models import Q, QuerySet, options
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.http import HttpRequest
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.timezone import now
|
||||
@@ -43,6 +44,7 @@ from authentik.lib.models import (
|
||||
DomainlessFormattedURLValidator,
|
||||
SerializerModel,
|
||||
)
|
||||
from authentik.lib.utils.inheritance import get_deepest_child
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.rbac.models import Role
|
||||
@@ -528,23 +530,35 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
|
||||
"default: in 30 days). See authentik logs for every will invocation of this "
|
||||
"deprecation."
|
||||
)
|
||||
stacktrace = traceback.format_stack()
|
||||
# The last line is this function, the next-to-last line is its caller
|
||||
cause = stacktrace[-2] if len(stacktrace) > 1 else "Unknown, see stacktrace in logs"
|
||||
if search := re.search(r'"(.*?)"', cause):
|
||||
cause = f"Property mapping or Expression policy named {search.group(1)}"
|
||||
|
||||
LOGGER.warning(
|
||||
"deprecation used",
|
||||
message=message_logger,
|
||||
deprecation=deprecation,
|
||||
replacement=replacement,
|
||||
cause=cause,
|
||||
stacktrace=stacktrace,
|
||||
)
|
||||
if not Event.filter_not_expired(
|
||||
action=EventAction.CONFIGURATION_WARNING, context__deprecation=deprecation
|
||||
action=EventAction.CONFIGURATION_WARNING,
|
||||
context__deprecation=deprecation,
|
||||
context__cause=cause,
|
||||
).exists():
|
||||
event = Event.new(
|
||||
EventAction.CONFIGURATION_WARNING,
|
||||
deprecation=deprecation,
|
||||
replacement=replacement,
|
||||
message=message_event,
|
||||
cause=cause,
|
||||
)
|
||||
event.expires = datetime.now() + timedelta(days=30)
|
||||
event.save()
|
||||
|
||||
return self.groups
|
||||
|
||||
def set_password(self, raw_password, signal=True, sender=None, request=None):
|
||||
@@ -789,25 +803,7 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
"""Get casted provider instance. Needs Application queryset with_provider"""
|
||||
if not self.provider:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
base_class = Provider
|
||||
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
|
||||
parent = self.provider
|
||||
for level in subclass.split(LOOKUP_SEP):
|
||||
try:
|
||||
parent = getattr(parent, level)
|
||||
except AttributeError:
|
||||
break
|
||||
if parent in candidates:
|
||||
continue
|
||||
idx = subclass.count(LOOKUP_SEP)
|
||||
if type(parent) is not base_class:
|
||||
idx += 1
|
||||
candidates.insert(idx, parent)
|
||||
if not candidates:
|
||||
return None
|
||||
return candidates[-1]
|
||||
return get_deepest_child(self.provider)
|
||||
|
||||
def backchannel_provider_for[T: Provider](self, provider_type: type[T], **kwargs) -> T | None:
|
||||
"""Get Backchannel provider for a specific type"""
|
||||
@@ -1119,7 +1115,11 @@ class ExpiringModel(models.Model):
|
||||
default the object is deleted. This is less efficient compared
|
||||
to bulk deleting objects, but classes like Token() need to change
|
||||
values instead of being deleted."""
|
||||
return self.delete(*args, **kwargs)
|
||||
try:
|
||||
return self.delete(*args, **kwargs)
|
||||
except self.DoesNotExist:
|
||||
# Object has already been deleted, so this should be fine
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def filter_not_expired(cls, **kwargs) -> QuerySet[Self]:
|
||||
|
||||
@@ -24,7 +24,8 @@ from authentik.root.ws.consumer import build_device_group
|
||||
|
||||
# Arguments: user: User, password: str
|
||||
password_changed = Signal()
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest,
|
||||
# stage: Stage, context: dict[str, any]
|
||||
login_failed = Signal()
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@@ -44,19 +44,24 @@
|
||||
{% endblock %}
|
||||
</div>
|
||||
</main>
|
||||
<footer aria-label="Site footer" class="pf-c-login__footer pf-m-dark">
|
||||
<ul class="pf-c-list pf-m-inline">
|
||||
{% for link in footer_links %}
|
||||
<li>
|
||||
<a href="{{ link.href }}">{{ link.name }}</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
<li>
|
||||
<span>
|
||||
{% trans 'Powered by authentik' %}
|
||||
</span>
|
||||
</li>
|
||||
</ul>
|
||||
<footer
|
||||
name="site-footer"
|
||||
aria-label="{% trans 'Site footer' %}"
|
||||
class="pf-c-login__footer pf-m-dark">
|
||||
<div name="flow-links" aria-label="{% trans 'Flow links' %}">
|
||||
<ul class="pf-c-list pf-m-inline" part="list">
|
||||
{% for link in footer_links %}
|
||||
<li part="list-item">
|
||||
<a part="list-item-link" href="{{ link.href }}">{{ link.name }}</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
<li part="list-item">
|
||||
<span>
|
||||
{% trans 'Powered by authentik' %}
|
||||
</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -78,7 +78,7 @@ def generate_key_id_legacy(key_data: str) -> str:
|
||||
"""Generate Key ID using MD5 (legacy format for backwards compatibility)."""
|
||||
if not key_data:
|
||||
return ""
|
||||
return md5(key_data.encode("utf-8")).hexdigest() # nosec
|
||||
return md5(key_data.encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
|
||||
|
||||
|
||||
class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.endpoints.api.connectors import ConnectorSerializer
|
||||
from authentik.endpoints.models import EndpointStage
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.models import Connector, EndpointStage
|
||||
from authentik.flows.api.stages import StageSerializer
|
||||
|
||||
|
||||
@@ -11,6 +14,13 @@ class EndpointStageSerializer(StageSerializer):
|
||||
|
||||
connector_obj = ConnectorSerializer(source="connector", read_only=True)
|
||||
|
||||
def validate_connector(self, connector: Connector) -> Connector:
|
||||
conn: Connector = Connector.objects.get_subclass(pk=connector.pk)
|
||||
controller = conn.controller(conn)
|
||||
if Capabilities.STAGE_ENDPOINTS not in controller.capabilities():
|
||||
raise ValidationError(_("Selected connector is not compatible with this stage."))
|
||||
return connector
|
||||
|
||||
class Meta:
|
||||
model = EndpointStage
|
||||
fields = StageSerializer.Meta.fields + [
|
||||
|
||||
@@ -18,7 +18,10 @@ from authentik.rbac.decorators import permission_required
|
||||
class EnrollmentTokenSerializer(ModelSerializer):
|
||||
|
||||
device_group_obj = DeviceAccessGroupSerializer(
|
||||
source="device_group", read_only=True, required=False
|
||||
source="device_group",
|
||||
read_only=True,
|
||||
required=False,
|
||||
allow_null=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
|
||||
@@ -37,6 +37,8 @@ class AgentEnrollmentAuth(BaseAuthentication):
|
||||
token = EnrollmentToken.filter_not_expired(key=key).first()
|
||||
if not token:
|
||||
raise PermissionDenied()
|
||||
if not token.connector.enabled:
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token_enrollment")
|
||||
return (DeviceUser(), token)
|
||||
|
||||
@@ -51,6 +53,8 @@ class AgentAuth(BaseAuthentication):
|
||||
device_token = DeviceToken.filter_not_expired(key=key).first()
|
||||
if not device_token:
|
||||
raise PermissionDenied()
|
||||
if not device_token.device.connector.enabled:
|
||||
raise PermissionDenied()
|
||||
if device_token.device.device.is_expired:
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token")
|
||||
|
||||
@@ -8,7 +8,7 @@ from rest_framework.fields import CharField
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, EnrollmentToken
|
||||
from authentik.endpoints.controller import BaseController
|
||||
from authentik.endpoints.controller import BaseController, Capabilities
|
||||
from authentik.endpoints.facts import OSFamily
|
||||
|
||||
|
||||
@@ -48,8 +48,8 @@ class AgentConnectorController(BaseController[AgentConnector]):
|
||||
def vendor_identifier() -> str:
|
||||
return "goauthentik.io/platform"
|
||||
|
||||
def supported_enrollment_methods(self):
|
||||
return []
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return [Capabilities.STAGE_ENDPOINTS]
|
||||
|
||||
def generate_mdm_config(
|
||||
self, target_platform: OSFamily, request: HttpRequest, token: EnrollmentToken
|
||||
|
||||
@@ -58,6 +58,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_enroll_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-enroll"),
|
||||
data={"device_serial": generate_id(), "device_name": "bar"},
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_enroll_token_delete(self):
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-enroll"),
|
||||
@@ -104,6 +114,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@reconcile_app("authentik_crypto")
|
||||
def test_config_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:agentconnector-agent-config"),
|
||||
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_check_in(self):
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-check-in"),
|
||||
@@ -112,6 +132,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_check_in_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-check-in"),
|
||||
data=CHECK_IN_DATA_VALID,
|
||||
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_check_in_token_expired(self):
|
||||
self.device_token.expiring = True
|
||||
self.device_token.expires = now() - timedelta(hours=1)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from hashlib import sha256
|
||||
from json import loads
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from jwt import encode
|
||||
@@ -232,3 +233,43 @@ class TestEndpointStage(FlowTestCase):
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertEqual(plan.context[PLAN_CONTEXT_DEVICE], self.device)
|
||||
|
||||
def test_endpoint_stage_connector_no_stage_optional(self):
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.OPTIONAL)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
with patch(
|
||||
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
|
||||
PropertyMock(return_value=None),
|
||||
):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)
|
||||
|
||||
def test_endpoint_stage_connector_no_stage_required(self):
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.REQUIRED)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
with patch(
|
||||
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
|
||||
PropertyMock(return_value=None),
|
||||
):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
self.assertStageResponse(
|
||||
res,
|
||||
component="ak-stage-access-denied",
|
||||
error_message="Invalid stage configuration",
|
||||
)
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)
|
||||
|
||||
@@ -8,13 +8,15 @@ from authentik.lib.sentry import SentryIgnoredException
|
||||
MERGED_VENDOR = "goauthentik.io/@merged"
|
||||
|
||||
|
||||
class EnrollmentMethods(models.TextChoices):
|
||||
class Capabilities(models.TextChoices):
|
||||
# Automatically enrolled through user action
|
||||
AUTOMATIC_USER = "automatic_user"
|
||||
ENROLL_AUTOMATIC_USER = "enroll_automatic_user"
|
||||
# Automatically enrolled through connector integration
|
||||
AUTOMATIC_API = "automatic_api"
|
||||
ENROLL_AUTOMATIC_API = "enroll_automatic_api"
|
||||
# Manually enrolled with user interaction (user scanning a QR code for example)
|
||||
MANUAL_USER = "manual_user"
|
||||
ENROLL_MANUAL_USER = "enroll_manual_user"
|
||||
# Supported for use with Endpoints stage
|
||||
STAGE_ENDPOINTS = "stage_endpoints"
|
||||
|
||||
|
||||
class ConnectorSyncException(SentryIgnoredException):
|
||||
@@ -34,7 +36,7 @@ class BaseController[T: "Connector"]:
|
||||
def vendor_identifier() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return []
|
||||
|
||||
def stage_view_enrollment(self) -> StageView | None:
|
||||
@@ -42,3 +44,6 @@ class BaseController[T: "Connector"]:
|
||||
|
||||
def stage_view_authentication(self) -> StageView | None:
|
||||
return None
|
||||
|
||||
def sync_endpoints(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -162,8 +162,11 @@ class Connector(ScheduledModel, SerializerModel):
|
||||
|
||||
@property
|
||||
def schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.tasks import endpoints_sync
|
||||
|
||||
if Capabilities.ENROLL_AUTOMATIC_API not in self.controller(self).capabilities():
|
||||
return []
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=endpoints_sync,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from authentik.endpoints.models import EndpointStage
|
||||
from authentik.endpoints.models import Connector, EndpointStage, StageMode
|
||||
from authentik.flows.stage import StageView
|
||||
|
||||
PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"
|
||||
@@ -6,15 +6,27 @@ PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"
|
||||
|
||||
class EndpointStageView(StageView):
|
||||
|
||||
def _get_inner(self):
|
||||
def _get_inner(self) -> StageView | None:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
inner_stage: type[StageView] | None = stage.connector.stage
|
||||
connector: Connector = stage.connector
|
||||
if not connector.enabled:
|
||||
return None
|
||||
inner_stage: type[StageView] | None = connector.stage
|
||||
if not inner_stage:
|
||||
return self.executor.stage_ok()
|
||||
return None
|
||||
return inner_stage(self.executor, request=self.request)
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self._get_inner().dispatch(request, *args, **kwargs)
|
||||
inner = self._get_inner()
|
||||
if inner is None:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
if stage.mode == StageMode.OPTIONAL:
|
||||
return self.executor.stage_ok()
|
||||
else:
|
||||
return self.executor.stage_invalid("Invalid stage configuration")
|
||||
return inner.dispatch(request, *args, **kwargs)
|
||||
|
||||
def cleanup(self):
|
||||
return self._get_inner().cleanup()
|
||||
inner = self._get_inner()
|
||||
if inner is not None:
|
||||
return inner.cleanup()
|
||||
|
||||
@@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.endpoints.controller import EnrollmentMethods
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.models import Connector
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -17,11 +17,11 @@ def endpoints_sync(connector_pk: Any):
|
||||
connector: Connector | None = (
|
||||
Connector.objects.filter(pk=connector_pk).select_subclasses().first()
|
||||
)
|
||||
if not connector:
|
||||
if not connector or not connector.enabled:
|
||||
return
|
||||
controller = connector.controller
|
||||
ctrl = controller(connector)
|
||||
if EnrollmentMethods.AUTOMATIC_API not in ctrl.supported_enrollment_methods():
|
||||
if Capabilities.ENROLL_AUTOMATIC_API not in ctrl.capabilities():
|
||||
return
|
||||
LOGGER.info("Syncing connector", connector=connector.name)
|
||||
ctrl.sync_endpoints()
|
||||
|
||||
41
authentik/endpoints/tests/test_api.py
Normal file
41
authentik/endpoints/tests/test_api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector
|
||||
from authentik.endpoints.models import StageMode
|
||||
from authentik.enterprise.endpoints.connectors.fleet.models import FleetConnector
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestAPI(APITestCase):
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_endpoint_stage_agent(self):
|
||||
connector = AgentConnector.objects.create(name=generate_id())
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:stages-endpoint-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"connector": str(connector.pk),
|
||||
"mode": StageMode.REQUIRED,
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 201)
|
||||
|
||||
def test_endpoint_stage_fleet(self):
|
||||
connector = FleetConnector.objects.create(name=generate_id())
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:stages-endpoint-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"connector": str(connector.pk),
|
||||
"mode": StageMode.REQUIRED,
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
res.content, {"connector": ["Selected connector is not compatible with this stage."]}
|
||||
)
|
||||
35
authentik/endpoints/tests/test_tasks.py
Normal file
35
authentik/endpoints/tests/test_tasks.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.endpoints.controller import BaseController, Capabilities
|
||||
from authentik.endpoints.models import Connector
|
||||
from authentik.endpoints.tasks import endpoints_sync
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestEndpointTasks(APITestCase):
|
||||
def test_agent_sync(self):
|
||||
class controller(BaseController):
|
||||
def capabilities(self):
|
||||
return [Capabilities.ENROLL_AUTOMATIC_API]
|
||||
|
||||
def sync_endpoints(self):
|
||||
pass
|
||||
|
||||
with patch.object(Connector, "controller", PropertyMock(return_value=controller)):
|
||||
connector = Connector.objects.create(name=generate_id())
|
||||
self.assertEqual(len(connector.schedule_specs), 1)
|
||||
|
||||
endpoints_sync.send(connector.pk).get_result(block=True)
|
||||
|
||||
def test_agent_no_sync(self):
|
||||
class controller(BaseController):
|
||||
def capabilities(self):
|
||||
return []
|
||||
|
||||
with patch.object(Connector, "controller", PropertyMock(return_value=controller)):
|
||||
connector = Connector.objects.create(name=generate_id())
|
||||
self.assertEqual(len(connector.schedule_specs), 0)
|
||||
|
||||
endpoints_sync.send(connector.pk).get_result(block=True)
|
||||
@@ -3,6 +3,7 @@ from hmac import compare_digest
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, QueryDict
|
||||
|
||||
from authentik.common.oauth.constants import QS_LOGIN_HINT
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
@@ -14,7 +15,7 @@ from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE, FlowPlanner
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
||||
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme
|
||||
|
||||
QS_AGENT_IA_TOKEN = "ak-auth-ia-token" # nosec
|
||||
@@ -64,14 +65,14 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):
|
||||
|
||||
planner = FlowPlanner(self.connector.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
context = {
|
||||
PLAN_CONTEXT_DEVICE: self.device,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
|
||||
}
|
||||
if QS_LOGIN_HINT in request.GET:
|
||||
context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = request.GET[QS_LOGIN_HINT]
|
||||
try:
|
||||
plan = planner.plan(
|
||||
self.request,
|
||||
{
|
||||
PLAN_CONTEXT_DEVICE: self.device,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
|
||||
},
|
||||
)
|
||||
plan = planner.plan(self.request, context)
|
||||
except FlowNonApplicableException:
|
||||
return self.handle_no_permission_authenticated()
|
||||
plan.append_stage(in_memory_stage(AgentAuthFulfillmentStage))
|
||||
@@ -84,7 +85,6 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):
|
||||
|
||||
|
||||
class AgentAuthFulfillmentStage(StageView):
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
device: Device = self.executor.plan.context.pop(PLAN_CONTEXT_DEVICE)
|
||||
auth_token: DeviceAuthenticationToken = self.executor.plan.context.pop(
|
||||
|
||||
@@ -6,7 +6,7 @@ from requests import RequestException
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.endpoints.controller import BaseController, ConnectorSyncException, EnrollmentMethods
|
||||
from authentik.endpoints.controller import BaseController, Capabilities, ConnectorSyncException
|
||||
from authentik.endpoints.facts import (
|
||||
DeviceFacts,
|
||||
OSFamily,
|
||||
@@ -43,8 +43,8 @@ class FleetController(BaseController[DBC]):
|
||||
def vendor_identifier() -> str:
|
||||
return "fleetdm.com"
|
||||
|
||||
def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
|
||||
return [EnrollmentMethods.AUTOMATIC_API]
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return [Capabilities.ENROLL_AUTOMATIC_API]
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
return f"{self.connector.url}{path}"
|
||||
|
||||
@@ -15,6 +15,7 @@ from django.core.cache import cache
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils.timezone import now
|
||||
from jwt import PyJWTError, decode, get_unverified_header
|
||||
from jwt.algorithms import ECAlgorithm
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import (
|
||||
ChoiceField,
|
||||
@@ -109,13 +110,20 @@ class LicenseKey:
|
||||
intermediate.verify_directly_issued_by(get_licensing_key())
|
||||
except InvalidSignature, TypeError, ValueError, Error:
|
||||
raise ValidationError("Unable to verify license") from None
|
||||
_validate_curve_original = ECAlgorithm._validate_curve
|
||||
try:
|
||||
# authentik's license are generated with `algorithm="ES512"` and signed with
|
||||
# a key of curve `secp384r1`. Starting with version 2.11.0, pyjwt enforces the spec, see
|
||||
# https://github.com/jpadilla/pyjwt/commit/5b8622773358e56d3d3c0a9acf404809ff34433a
|
||||
# authentik will change its license generation to `algorithm="ES384"` in 2026.
|
||||
# TODO: remove this when the last incompatible license runs out.
|
||||
ECAlgorithm._validate_curve = lambda *_: True
|
||||
body = from_dict(
|
||||
LicenseKey,
|
||||
decode(
|
||||
jwt,
|
||||
our_cert.public_key(),
|
||||
algorithms=["ES512"],
|
||||
algorithms=["ES384", "ES512"],
|
||||
audience=get_license_aud(),
|
||||
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
|
||||
),
|
||||
@@ -125,6 +133,8 @@ class LicenseKey:
|
||||
if unverified["aud"] != get_license_aud():
|
||||
raise ValidationError("Invalid Install ID in license") from None
|
||||
raise ValidationError("Unable to verify license") from None
|
||||
finally:
|
||||
ECAlgorithm._validate_curve = _validate_curve_original
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from datetime import date
|
||||
from datetime import datetime
|
||||
|
||||
from django.db.models import BooleanField as ModelBooleanField
|
||||
from django.db.models import Case, Q, Value, When
|
||||
from django_filters.rest_framework import BooleanFilter, FilterSet
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_field
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import DateField, IntegerField, SerializerMethodField
|
||||
from rest_framework.fields import IntegerField, SerializerMethodField
|
||||
from rest_framework.mixins import CreateModelMixin
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@@ -21,6 +21,7 @@ from authentik.enterprise.lifecycle.utils import (
|
||||
ReviewerUserSerializer,
|
||||
admin_link_for_model,
|
||||
parse_content_type,
|
||||
start_of_day,
|
||||
)
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
|
||||
@@ -67,13 +68,13 @@ class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
def get_object_admin_url(self, iteration: LifecycleIteration) -> str:
|
||||
return admin_link_for_model(iteration.object)
|
||||
|
||||
@extend_schema_field(DateField())
|
||||
def get_grace_period_end(self, iteration: LifecycleIteration) -> date:
|
||||
return iteration.opened_on + timedelta_from_string(iteration.rule.grace_period)
|
||||
def get_grace_period_end(self, iteration: LifecycleIteration) -> datetime:
|
||||
return start_of_day(
|
||||
iteration.opened_on + timedelta_from_string(iteration.rule.grace_period)
|
||||
)
|
||||
|
||||
@extend_schema_field(DateField())
|
||||
def get_next_review_date(self, iteration: LifecycleIteration):
|
||||
return iteration.opened_on + timedelta_from_string(iteration.rule.interval)
|
||||
def get_next_review_date(self, iteration: LifecycleIteration) -> datetime:
|
||||
return start_of_day(iteration.opened_on + timedelta_from_string(iteration.rule.interval))
|
||||
|
||||
def get_user_can_review(self, iteration: LifecycleIteration) -> bool:
|
||||
return iteration.user_can_review(self.context["request"].user)
|
||||
@@ -102,7 +103,7 @@ class IterationViewSet(EnterpriseRequiredMixin, CreateModelMixin, GenericViewSet
|
||||
default=Value(False),
|
||||
output_field=ModelBooleanField(),
|
||||
)
|
||||
)
|
||||
).distinct()
|
||||
|
||||
@action(
|
||||
detail=False,
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-13 09:33
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_lifecycle", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="lifecycleiteration",
|
||||
name="opened_on",
|
||||
field=models.DateTimeField(auto_now_add=True),
|
||||
),
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
@@ -13,7 +14,7 @@ from rest_framework.serializers import BaseSerializer
|
||||
|
||||
from authentik.blueprints.models import ManagedModel
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.enterprise.lifecycle.utils import link_for_model
|
||||
from authentik.enterprise.lifecycle.utils import link_for_model, start_of_day
|
||||
from authentik.events.models import Event, EventAction, NotificationSeverity, NotificationTransport
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
|
||||
@@ -98,7 +99,9 @@ class LifecycleRule(SerializerModel):
|
||||
|
||||
def _get_newly_overdue_iterations(self) -> QuerySet[LifecycleIteration]:
|
||||
return self.lifecycleiteration_set.filter(
|
||||
opened_on__lte=timezone.now() - timedelta_from_string(self.grace_period),
|
||||
opened_on__lt=start_of_day(
|
||||
timezone.now() + timedelta(days=1) - timedelta_from_string(self.grace_period)
|
||||
),
|
||||
state=ReviewState.PENDING,
|
||||
)
|
||||
|
||||
@@ -106,7 +109,9 @@ class LifecycleRule(SerializerModel):
|
||||
recent_iteration_ids = LifecycleIteration.objects.filter(
|
||||
content_type=self.content_type,
|
||||
object_id__isnull=False,
|
||||
opened_on__gte=timezone.now() - timedelta_from_string(self.interval),
|
||||
opened_on__gte=start_of_day(
|
||||
timezone.now() + timedelta(days=1) - timedelta_from_string(self.interval)
|
||||
),
|
||||
).values_list(Cast("object_id", output_field=self._get_pk_field()), flat=True)
|
||||
|
||||
return self.get_objects().exclude(pk__in=recent_iteration_ids)
|
||||
@@ -186,7 +191,7 @@ class LifecycleIteration(SerializerModel, ManagedModel):
|
||||
rule = models.ForeignKey(LifecycleRule, null=True, on_delete=models.SET_NULL)
|
||||
|
||||
state = models.CharField(max_length=10, choices=ReviewState, default=ReviewState.PENDING)
|
||||
opened_on = models.DateField(auto_now_add=True)
|
||||
opened_on = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
indexes = [models.Index(fields=["content_type", "opened_on"])]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import datetime as dt
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -319,7 +320,7 @@ class TestLifecycleModels(TestCase):
|
||||
content_type=content_type, object_id=str(app_one.pk), rule=rule_overdue
|
||||
)
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=(timezone.now().date() - timedelta(days=20))
|
||||
opened_on=(timezone.now() - timedelta(days=20))
|
||||
)
|
||||
|
||||
# Apply again to trigger overdue logic
|
||||
@@ -383,7 +384,7 @@ class TestLifecycleModels(TestCase):
|
||||
content_type=content_type, object_id=str(app_overdue.pk), rule=rule_overdue
|
||||
)
|
||||
LifecycleIteration.objects.filter(pk=overdue_iteration.pk).update(
|
||||
opened_on=(timezone.now().date() - timedelta(days=20))
|
||||
opened_on=(timezone.now() - timedelta(days=20))
|
||||
)
|
||||
|
||||
# Apply overdue rule to mark iteration as overdue
|
||||
@@ -667,3 +668,178 @@ class TestLifecycleModels(TestCase):
|
||||
reviewers = list(rule.get_reviewers())
|
||||
self.assertIn(explicit_reviewer, reviewers)
|
||||
self.assertIn(group_member, reviewers)
|
||||
|
||||
|
||||
class TestLifecycleDateBoundaries(TestCase):
|
||||
"""Verify that start_of_day normalization ensures correct overdue/due
|
||||
detection regardless of exact task execution time within a day.
|
||||
|
||||
The daily task may run at any point during the day. The start_of_day
|
||||
normalization in _get_newly_overdue_iterations and _get_newly_due_objects
|
||||
ensures that the boundary is always at midnight, so millisecond variations
|
||||
in task execution time do not affect results."""
|
||||
|
||||
def _create_rule_and_iteration(self, grace_period="days=1", interval="days=365"):
|
||||
app = Application.objects.create(name=generate_id(), slug=generate_id())
|
||||
content_type = ContentType.objects.get_for_model(Application)
|
||||
rule = LifecycleRule.objects.create(
|
||||
name=generate_id(),
|
||||
content_type=content_type,
|
||||
object_id=str(app.pk),
|
||||
interval=interval,
|
||||
grace_period=grace_period,
|
||||
)
|
||||
iteration = LifecycleIteration.objects.get(
|
||||
content_type=content_type, object_id=str(app.pk), rule=rule
|
||||
)
|
||||
return app, rule, iteration
|
||||
|
||||
def test_overdue_iteration_opened_yesterday(self):
|
||||
"""grace_period=1 day: iteration opened yesterday at any time is overdue today."""
|
||||
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
for opened_on in [
|
||||
dt.datetime(2025, 6, 14, 0, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 14, 12, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 14, 23, 59, 59, 999999, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(opened_on=opened_on):
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=opened_on, state=ReviewState.PENDING
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))
|
||||
|
||||
def test_not_overdue_iteration_opened_today(self):
|
||||
"""grace_period=1 day: iteration opened today at any time is NOT overdue."""
|
||||
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
for opened_on in [
|
||||
dt.datetime(2025, 6, 15, 0, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 23, 59, 59, 999999, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(opened_on=opened_on):
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=opened_on, state=ReviewState.PENDING
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertNotIn(iteration, list(rule._get_newly_overdue_iterations()))
|
||||
|
||||
def test_overdue_independent_of_task_execution_time(self):
|
||||
"""Overdue detection gives the same result whether the task runs at 00:00:01 or 23:59:59."""
|
||||
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
|
||||
opened_on = dt.datetime(2025, 6, 14, 18, 0, 0, tzinfo=dt.UTC)
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=opened_on, state=ReviewState.PENDING
|
||||
)
|
||||
for task_time in [
|
||||
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 12, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(task_time=task_time):
|
||||
with patch("django.utils.timezone.now", return_value=task_time):
|
||||
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))
|
||||
|
||||
def test_overdue_boundary_multi_day_grace_period(self):
|
||||
"""grace_period=30 days: overdue after 30 full days, not after 29."""
|
||||
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=30")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
|
||||
# Opened 30 days ago (May 16), should go overdue
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=dt.datetime(2025, 5, 16, 12, 0, 0, tzinfo=dt.UTC),
|
||||
state=ReviewState.PENDING,
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))
|
||||
|
||||
# Opened 29 days ago (May 17), should NOT go overdue
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=dt.datetime(2025, 5, 17, 12, 0, 0, tzinfo=dt.UTC),
|
||||
state=ReviewState.PENDING,
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertNotIn(iteration, list(rule._get_newly_overdue_iterations()))
|
||||
|
||||
def test_due_object_iteration_opened_yesterday(self):
|
||||
"""interval=1 day: object with iteration opened yesterday is due for a new review."""
|
||||
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
for opened_on in [
|
||||
dt.datetime(2025, 6, 14, 0, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 14, 12, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 14, 23, 59, 59, 999999, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(opened_on=opened_on):
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertIn(app, list(rule._get_newly_due_objects()))
|
||||
|
||||
def test_not_due_object_iteration_opened_today(self):
|
||||
"""interval=1 day: object with iteration opened today is NOT due."""
|
||||
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
for opened_on in [
|
||||
dt.datetime(2025, 6, 15, 0, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 23, 59, 59, 999999, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(opened_on=opened_on):
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertNotIn(app, list(rule._get_newly_due_objects()))
|
||||
|
||||
def test_due_independent_of_task_execution_time(self):
|
||||
"""Due detection gives the same result whether the task runs at 00:00:01 or 23:59:59."""
|
||||
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
|
||||
opened_on = dt.datetime(2025, 6, 14, 18, 0, 0, tzinfo=dt.UTC)
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
|
||||
for task_time in [
|
||||
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 12, 0, 0, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(task_time=task_time):
|
||||
with patch("django.utils.timezone.now", return_value=task_time):
|
||||
self.assertIn(app, list(rule._get_newly_due_objects()))
|
||||
|
||||
def test_due_boundary_multi_day_interval(self):
|
||||
"""interval=30 days: due after 30 full days, not after 29."""
|
||||
app, rule, iteration = self._create_rule_and_iteration(interval="days=30")
|
||||
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
|
||||
|
||||
# Previous review opened 30 days ago (May 16), review is due for the object
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=dt.datetime(2025, 5, 16, 12, 0, 0, tzinfo=dt.UTC)
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertIn(app, list(rule._get_newly_due_objects()))
|
||||
|
||||
# Previous review opened 29 days ago (May 17), new review is NOT due
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=dt.datetime(2025, 5, 17, 12, 0, 0, tzinfo=dt.UTC)
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=fixed_now):
|
||||
self.assertNotIn(app, list(rule._get_newly_due_objects()))
|
||||
|
||||
def test_apply_overdue_at_boundary(self):
|
||||
"""apply() marks iteration overdue when grace period just expired,
|
||||
regardless of what time the daily task runs."""
|
||||
_, rule, iteration = self._create_rule_and_iteration(
|
||||
grace_period="days=1", interval="days=365"
|
||||
)
|
||||
opened_on = dt.datetime(2025, 6, 14, 20, 0, 0, tzinfo=dt.UTC)
|
||||
for task_time in [
|
||||
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
|
||||
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
|
||||
]:
|
||||
with self.subTest(task_time=task_time):
|
||||
LifecycleIteration.objects.filter(pk=iteration.pk).update(
|
||||
opened_on=opened_on, state=ReviewState.PENDING
|
||||
)
|
||||
with patch("django.utils.timezone.now", return_value=task_time):
|
||||
rule.apply()
|
||||
iteration.refresh_from_db()
|
||||
self.assertEqual(iteration.state, ReviewState.OVERDUE)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from urllib import parse
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
@@ -39,6 +40,10 @@ def link_for_model(model: Model) -> str:
|
||||
return f"{reverse("authentik_core:if-admin")}#{admin_link_for_model(model)}"
|
||||
|
||||
|
||||
def start_of_day(dt: datetime) -> datetime:
|
||||
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
|
||||
class ContentTypeField(ChoiceField):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(choices=model_choices(), **kwargs)
|
||||
|
||||
@@ -331,7 +331,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test group discovery"""
|
||||
@@ -372,7 +372,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response to trigger update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -309,7 +309,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test user discovery, running multiple times"""
|
||||
@@ -352,7 +352,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response, which will trigger a discovery update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -78,7 +78,8 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
microsoft_user = self.to_schema(user, None)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
if microsoft_user.user_principal_name:
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
with transaction.atomic():
|
||||
try:
|
||||
response = self._request(self.client.users.post(microsoft_user))
|
||||
@@ -118,7 +119,8 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
def update(self, user: User, connection: MicrosoftEntraProviderUser):
|
||||
"""Update existing user"""
|
||||
microsoft_user = self.to_schema(user, connection)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
if microsoft_user.user_principal_name:
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
response = self._request(
|
||||
self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user)
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.urls import reverse
|
||||
from rest_framework.fields import CharField, SerializerMethodField, URLField
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.ws_federation.models import WSFederationProvider
|
||||
from authentik.enterprise.providers.ws_federation.processors.metadata import MetadataProcessor
|
||||
@@ -18,6 +19,29 @@ class WSFederationProviderSerializer(EnterpriseRequiredMixin, SAMLProviderSerial
|
||||
wtrealm = CharField(source="audience")
|
||||
url_wsfed = SerializerMethodField()
|
||||
|
||||
def get_url_download_metadata(self, instance: WSFederationProvider) -> str:
|
||||
"""Get metadata download URL"""
|
||||
if "request" not in self._context:
|
||||
return ""
|
||||
request: HttpRequest = self._context["request"]._request
|
||||
try:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_ws_federation:metadata-download",
|
||||
kwargs={"application_slug": instance.application.slug},
|
||||
)
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_api:wsfederationprovider-metadata",
|
||||
kwargs={
|
||||
"pk": instance.pk,
|
||||
},
|
||||
)
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
def get_url_wsfed(self, instance: WSFederationProvider) -> str:
|
||||
"""Get WS-Fed url"""
|
||||
if "request" not in self._context:
|
||||
|
||||
@@ -81,6 +81,8 @@ class SignInProcessor:
|
||||
self.sign_in_request = sign_in_request
|
||||
self.saml_processor = AssertionProcessor(self.provider, self.request, AuthNRequest())
|
||||
self.saml_processor.provider.audience = self.sign_in_request.wtrealm
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor.provider.sign_assertion = True
|
||||
|
||||
def create_response_token(self):
|
||||
root = Element(f"{{{NS_WS_FED_TRUST}}}RequestSecurityTokenResponse", nsmap=NS_MAP)
|
||||
@@ -148,7 +150,8 @@ class SignInProcessor:
|
||||
def response(self) -> dict[str, str]:
|
||||
root = self.create_response_token()
|
||||
assertion = root.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self.saml_processor._sign(assertion)
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor._sign(assertion)
|
||||
str_token = etree.tostring(root).decode("utf-8") # nosec
|
||||
return delete_none_values(
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from django.urls import path
|
||||
|
||||
from authentik.enterprise.providers.ws_federation.api.providers import WSFederationProviderViewSet
|
||||
from authentik.enterprise.providers.ws_federation.views import WSFedEntryView
|
||||
from authentik.enterprise.providers.ws_federation.views import MetadataDownload, WSFedEntryView
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
@@ -11,6 +11,12 @@ urlpatterns = [
|
||||
WSFedEntryView.as_view(),
|
||||
name="wsfed",
|
||||
),
|
||||
# Metadata
|
||||
path(
|
||||
"<slug:application_slug>/metadata/",
|
||||
MetadataDownload.as_view(),
|
||||
name="metadata-download",
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.shortcuts import get_object_or_404, redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from django.views import View
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
@@ -160,3 +162,24 @@ class WSFedFlowFinalView(ChallengeStageView):
|
||||
"attrs": response,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MetadataDownload(View):
|
||||
"""Redirect to metadata download"""
|
||||
|
||||
def dispatch(self, request: HttpRequest, application_slug: str) -> HttpResponse:
|
||||
app = Application.objects.filter(slug=application_slug).with_provider().first()
|
||||
if not app:
|
||||
raise Http404
|
||||
provider = app.get_provider()
|
||||
if not provider:
|
||||
raise Http404
|
||||
return redirect(
|
||||
reverse(
|
||||
"authentik_api:wsfederationprovider-metadata",
|
||||
kwargs={
|
||||
"pk": provider.pk,
|
||||
},
|
||||
)
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
@@ -93,11 +93,13 @@ def on_login_failed(
|
||||
credentials: dict[str, str],
|
||||
request: HttpRequest,
|
||||
stage: Stage | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Failed Login, authentik custom event"""
|
||||
user = User.objects.filter(username=credentials.get("username")).first()
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **kwargs).from_http(
|
||||
context = context or {}
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **context).from_http(
|
||||
request, user
|
||||
)
|
||||
|
||||
|
||||
@@ -207,3 +207,9 @@ class TestEvents(TestCase):
|
||||
"username": user.username,
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_string(self):
|
||||
"""Test creating an event with invalid unicode string data"""
|
||||
event = Event.new("unittest", foo="foo bar \u0000 baz")
|
||||
event.save()
|
||||
self.assertEqual(event.context["foo"], "foo bar baz")
|
||||
|
||||
@@ -36,6 +36,10 @@ ALLOWED_SPECIAL_KEYS = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def cleanse_str(raw: Any) -> str:
|
||||
return str(raw).replace("\u0000", "")
|
||||
|
||||
|
||||
def cleanse_item(key: str, value: Any) -> Any:
|
||||
"""Cleanse a single item"""
|
||||
if isinstance(value, dict):
|
||||
@@ -66,7 +70,7 @@ def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
def model_to_dict(model: Model) -> dict[str, Any]:
|
||||
"""Convert model to dict"""
|
||||
name = str(model)
|
||||
name = cleanse_str(model)
|
||||
if hasattr(model, "name"):
|
||||
name = model.name
|
||||
return {
|
||||
@@ -133,11 +137,11 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
if isinstance(value, ASN):
|
||||
return ASN_CONTEXT_PROCESSOR.asn_to_dict(value)
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Exception):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, YAMLTag):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, type):
|
||||
@@ -161,7 +165,7 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
raise ValueError("JSON can't represent timezone-aware times.")
|
||||
return value.isoformat()
|
||||
if isinstance(value, timedelta):
|
||||
return str(value.total_seconds())
|
||||
return cleanse_str(value.total_seconds())
|
||||
if callable(value):
|
||||
return {
|
||||
"type": "callable",
|
||||
@@ -174,8 +178,8 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
try:
|
||||
return DjangoJSONEncoder().default(value)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
return cleanse_str(value)
|
||||
|
||||
|
||||
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
@@ -29,6 +29,12 @@ class RefreshOtherFlowsAfterAuthentication(Flag[bool], key="flows_refresh_others
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class ContinuousLogin(Flag[bool], key="flows_continuous_login"):
|
||||
|
||||
default = False
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class AuthentikFlowsConfig(ManagedAppConfig):
|
||||
"""authentik flows app config"""
|
||||
|
||||
|
||||
@@ -9,7 +9,15 @@
|
||||
{{ block.super }}
|
||||
<link rel="prefetch" href="{{ flow_background_url }}" />
|
||||
{% if flow.compatibility_mode and not inspector %}
|
||||
<script data-id="shady-dom">ShadyDOM = { force: true };</script>
|
||||
{% comment %}
|
||||
@see {@link web/types/webcomponents.d.ts} for type definitions.
|
||||
{% endcomment %}
|
||||
<script data-id="shady-dom">
|
||||
"use strict";
|
||||
|
||||
window.ShadyDOM = window.ShadyDOM || {}
|
||||
window.ShadyDOM.force = true
|
||||
</script>
|
||||
{% endif %}
|
||||
{% include "base/header_js.html" %}
|
||||
<script data-id="flow-config">
|
||||
@@ -45,16 +53,11 @@
|
||||
slug="{{ flow.slug }}"
|
||||
class="pf-c-login"
|
||||
data-layout="{{ flow.layout|default:'stacked' }}"
|
||||
loading
|
||||
>
|
||||
{% include "base/placeholder.html" %}
|
||||
|
||||
<ak-brand-links
|
||||
slot="footer"
|
||||
exportparts="list:brand-links-list, list-item:brand-links-list-item"
|
||||
role="contentinfo"
|
||||
aria-label="{% trans 'Site footer' %}"
|
||||
class="pf-c-login__footer {% if flow.layout == 'stacked' %}pf-m-dark{% endif %}"
|
||||
></ak-brand-links>
|
||||
<ak-brand-links name="flow-links" slot="footer"></ak-brand-links>
|
||||
</ak-flow-executor>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -141,6 +141,10 @@ web:
|
||||
# workers: 2
|
||||
threads: 4
|
||||
path: /
|
||||
timeout_http_read_header: 5s
|
||||
timeout_http_read: 30s
|
||||
timeout_http_write: 60s
|
||||
timeout_http_idle: 120s
|
||||
|
||||
worker:
|
||||
processes: 1
|
||||
@@ -178,3 +182,5 @@ storage:
|
||||
# backend: file # or s3
|
||||
# file: {}
|
||||
# s3: {}
|
||||
|
||||
skip_migrations: false
|
||||
|
||||
@@ -42,7 +42,7 @@ ARG_SANITIZE = re.compile(r"[:.-]")
|
||||
|
||||
|
||||
def sanitize_arg(arg_name: str) -> str:
|
||||
return re.sub(ARG_SANITIZE, "_", arg_name)
|
||||
return re.sub(ARG_SANITIZE, "_", slugify(arg_name))
|
||||
|
||||
|
||||
class BaseEvaluator:
|
||||
@@ -311,7 +311,9 @@ class BaseEvaluator:
|
||||
|
||||
def wrap_expression(self, expression: str) -> str:
|
||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
|
||||
handler_signature = ",".join(
|
||||
[x for x in [sanitize_arg(x) for x in self._context.keys()] if x]
|
||||
)
|
||||
full_expression = ""
|
||||
full_expression += f"def handler({handler_signature}):\n"
|
||||
full_expression += indent(expression, " ")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Any, Self
|
||||
|
||||
import pglock
|
||||
@@ -68,7 +69,12 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
return Paginator(self.get_object_qs(type), self.sync_page_size)
|
||||
|
||||
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
|
||||
num_pages: int = self.get_paginator(type).num_pages
|
||||
# Use a simple COUNT(*) on the model instead of materializing get_object_qs(),
|
||||
# which for some providers (e.g. SCIM) runs PolicyEngine per-user and is
|
||||
# extremely expensive. The time limit is an upper-bound estimate, so using
|
||||
# the total count (without policy filtering) is a safe overestimate.
|
||||
total_count = type.objects.count()
|
||||
num_pages = math.ceil(total_count / self.sync_page_size) if total_count > 0 else 1
|
||||
page_timeout_ms = timedelta_from_string(self.sync_page_timeout).total_seconds() * 1000
|
||||
return int(num_pages * page_timeout_ms * 1.5)
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ class SyncTasks:
|
||||
)
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
|
||||
self._sync_cleanup(provider, task)
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
|
||||
@@ -111,6 +112,35 @@ class SyncTasks:
|
||||
task.error(exc)
|
||||
return
|
||||
|
||||
def _sync_cleanup(self, provider: OutgoingSyncProvider, task: Task):
|
||||
"""Delete remote objects that are no longer in scope"""
|
||||
for object_type in (User, Group):
|
||||
try:
|
||||
client = provider.client_for_model(object_type)
|
||||
except TransientSyncException:
|
||||
continue
|
||||
in_scope_pks = set(provider.get_object_qs(object_type).values_list("pk", flat=True))
|
||||
stale = client.connection_type.objects.filter(provider=provider).exclude(
|
||||
**{f"{client.connection_type_query}__pk__in": in_scope_pks}
|
||||
)
|
||||
for connection in stale:
|
||||
try:
|
||||
client.delete(connection.scim_id)
|
||||
task.info(
|
||||
f"Deleted out-of-scope {object_type._meta.verbose_name}",
|
||||
scim_id=connection.scim_id,
|
||||
)
|
||||
except NotFoundSyncException:
|
||||
pass
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient error during cleanup", exc=exc)
|
||||
self.logger.warning(
|
||||
"Cleanup encountered a transient exception. Retrying", exc=exc
|
||||
)
|
||||
raise Retry() from exc
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run cleanup event", exc=exc)
|
||||
|
||||
def sync_objects(
|
||||
self,
|
||||
object_type: str,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Test Evaluator base functions"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
@@ -353,3 +354,18 @@ class TestEvaluator(TestCase):
|
||||
self.assertEqual(message.to, ["to@example.com"])
|
||||
self.assertEqual(message.cc, ["cc1@example.com", "cc2@example.com"])
|
||||
self.assertEqual(message.bcc, ["bcc1@example.com", "bcc2@example.com"])
|
||||
|
||||
def test_expr_arg_escape(self):
|
||||
"""Test escaping of arguments"""
|
||||
eval = BaseEvaluator()
|
||||
eval._context = {
|
||||
'z=getattr(getattr(__import__("os"), "popen")("id > /tmp/test"), "read")()': "bar",
|
||||
"@@": "baz",
|
||||
"{{": "baz",
|
||||
"aa@@": "baz",
|
||||
}
|
||||
res = eval.evaluate("return locals()")
|
||||
self.assertEqual(
|
||||
res, {"zgetattrgetattr__import__os_popenid_tmptest_read": "bar", "aa": "baz"}
|
||||
)
|
||||
self.assertFalse(Path("/tmp/test").exists())
|
||||
|
||||
119
authentik/lib/tests/test_utils_inheritance.py
Normal file
119
authentik/lib/tests/test_utils_inheritance.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Tests for inheritance helpers."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connection, models
|
||||
from django.test import TransactionTestCase
|
||||
from django.test.utils import isolate_apps
|
||||
|
||||
from authentik.lib.utils.inheritance import get_deepest_child
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_inheritance_models():
|
||||
"""Create a temporary multi-table inheritance graph for testing."""
|
||||
with isolate_apps("authentik.lib.tests"):
|
||||
|
||||
class GrandParent(models.Model):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandParent({self.pk})"
|
||||
|
||||
class Parent(GrandParent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Parent({self.pk})"
|
||||
|
||||
class Child(Parent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Child({self.pk})"
|
||||
|
||||
class GrandChild(Child):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandChild({self.pk})"
|
||||
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.create_model(GrandParent)
|
||||
schema_editor.create_model(Parent)
|
||||
schema_editor.create_model(Child)
|
||||
schema_editor.create_model(GrandChild)
|
||||
|
||||
try:
|
||||
yield GrandParent, Parent, Child, GrandChild
|
||||
finally:
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.delete_model(GrandChild)
|
||||
schema_editor.delete_model(Child)
|
||||
schema_editor.delete_model(Parent)
|
||||
schema_editor.delete_model(GrandParent)
|
||||
|
||||
|
||||
class TestInheritanceUtils(TransactionTestCase):
|
||||
"""Tests for helper functions in authentik.lib.utils.inheritance."""
|
||||
|
||||
def test_get_deepest_child_grandparent_to_parent(self):
|
||||
"""GrandParent -> Parent."""
|
||||
with temporary_inheritance_models() as (GrandParent, Parent, _Child, _GrandChild):
|
||||
parent = Parent.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=parent.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Parent)
|
||||
self.assertEqual(resolved.pk, parent.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_child(self):
|
||||
"""GrandParent -> Child."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_grandchild(self):
|
||||
"""GrandParent -> GrandChild."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=grandchild.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
self.assertEqual(resolved.pk, grandchild.pk)
|
||||
|
||||
def test_get_deepest_child_parent_to_child(self):
|
||||
"""Parent -> Child (start from non-root)."""
|
||||
with temporary_inheritance_models() as (_GrandParent, Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
parent = Parent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(parent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_no_queries_with_preloaded_relations(self):
|
||||
"""No extra queries when the inheritance chain is fully select_related."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.select_related("parent__child__grandchild").get(
|
||||
pk=grandchild.pk
|
||||
)
|
||||
|
||||
with self.assertNumQueries(0):
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
41
authentik/lib/utils/inheritance.py
Normal file
41
authentik/lib/utils/inheritance.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from django.db.models import Model, OneToOneField, OneToOneRel
|
||||
|
||||
|
||||
def get_deepest_child(parent: Model) -> Model:
|
||||
"""
|
||||
In multiple table inheritance, given any ancestor object, get the deepest child object.
|
||||
See https://docs.djangoproject.com/en/dev/topics/db/models/#multi-table-inheritance
|
||||
|
||||
This function does not query the database if `select_related` has been performed on all
|
||||
subclasses of `parent`'s model.
|
||||
"""
|
||||
|
||||
# Almost verbatim copy from django-model-utils, see
|
||||
# https://github.com/jazzband/django-model-utils/blob/5.0.0/model_utils/managers.py#L132
|
||||
one_to_one_rels = [
|
||||
field for field in parent._meta.get_fields() if isinstance(field, OneToOneRel)
|
||||
]
|
||||
|
||||
submodel_fields = [
|
||||
rel
|
||||
for rel in one_to_one_rels
|
||||
if isinstance(rel.field, OneToOneField)
|
||||
and issubclass(rel.field.model, parent._meta.model)
|
||||
and parent._meta.model is not rel.field.model
|
||||
and rel.parent_link
|
||||
]
|
||||
|
||||
submodel_accessors = [submodel_field.get_accessor_name() for submodel_field in submodel_fields]
|
||||
# End Copy
|
||||
|
||||
child = None
|
||||
for submodel in submodel_accessors:
|
||||
try:
|
||||
child = getattr(parent, submodel)
|
||||
break
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if not child:
|
||||
return parent
|
||||
return get_deepest_child(child)
|
||||
@@ -185,8 +185,10 @@ class KubernetesObjectReconciler[T]:
|
||||
|
||||
patch = self.get_patch()
|
||||
if patch is not None:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
|
||||
try:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
except AttributeError:
|
||||
current_json = asdict(current)
|
||||
try:
|
||||
if apply_patch(current_json, patch) != current_json:
|
||||
raise NeedsUpdate()
|
||||
|
||||
@@ -163,4 +163,5 @@ def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
if Outpost.objects.exists():
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
|
||||
@@ -7,7 +7,6 @@ from socket import gethostname
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@@ -159,7 +158,7 @@ def outpost_send_update(pk: Any):
|
||||
layer = get_channel_layer()
|
||||
group = build_outpost_group(outpost.pk)
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
layer.group_send_blocking(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@actor(description=_("Checks the local environment and create Service connections."))
|
||||
@@ -210,7 +209,7 @@ def outpost_session_end(session_id: str):
|
||||
for outpost in Outpost.objects.all():
|
||||
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
group,
|
||||
{
|
||||
"type": "event.session.end",
|
||||
|
||||
@@ -57,9 +57,11 @@ class PolicyBindingSerializer(ModelSerializer):
|
||||
required=True,
|
||||
)
|
||||
|
||||
policy_obj = PolicySerializer(required=False, read_only=True, source="policy")
|
||||
group_obj = PartialGroupSerializer(required=False, read_only=True, source="group")
|
||||
user_obj = PartialUserSerializer(required=False, read_only=True, source="user")
|
||||
policy_obj = PolicySerializer(required=False, allow_null=True, read_only=True, source="policy")
|
||||
group_obj = PartialGroupSerializer(
|
||||
required=False, allow_null=True, read_only=True, source="group"
|
||||
)
|
||||
user_obj = PartialUserSerializer(required=False, allow_null=True, read_only=True, source="user")
|
||||
|
||||
class Meta:
|
||||
model = PolicyBinding
|
||||
|
||||
@@ -132,9 +132,14 @@ class PolicyEngine:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
if self.mode == PolicyEngineMode.MODE_ANY:
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif self.mode == PolicyEngineMode.MODE_ALL:
|
||||
if matched_bindings.get("passing", 0) == matched_bindings["total"]:
|
||||
# All static bindings are passing -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
@@ -185,6 +190,16 @@ class PolicyEngine:
|
||||
# Only call .recv() if no result is saved, otherwise we just deadlock here
|
||||
if not proc_info.result:
|
||||
proc_info.result = proc_info.connection.recv()
|
||||
if proc_info.result and proc_info.result._exec_time:
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=proc_info.binding.order,
|
||||
binding_target_type=proc_info.binding.target_type,
|
||||
binding_target_name=proc_info.binding.target_name,
|
||||
object_type=(
|
||||
class_to_path(self.request.obj.__class__) if self.request.obj else ""
|
||||
),
|
||||
mode="execute_process",
|
||||
).observe(proc_info.result._exec_time)
|
||||
return self
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from multiprocessing import get_context
|
||||
from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from sentry_sdk import start_span
|
||||
@@ -11,8 +12,6 @@ from structlog.stdlib import get_logger
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
||||
@@ -123,18 +122,9 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
|
||||
def profiling_wrapper(self):
|
||||
"""Run with profiling enabled"""
|
||||
with (
|
||||
start_span(
|
||||
op="authentik.policy.process.execute",
|
||||
) as span,
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=self.binding.order,
|
||||
binding_target_type=self.binding.target_type,
|
||||
binding_target_name=self.binding.target_name,
|
||||
object_type=class_to_path(self.request.obj.__class__) if self.request.obj else "",
|
||||
mode="execute_process",
|
||||
).time(),
|
||||
):
|
||||
with start_span(
|
||||
op="authentik.policy.process.execute",
|
||||
) as span:
|
||||
span: Span
|
||||
span.set_data("policy", self.binding.policy)
|
||||
span.set_data("request", self.request)
|
||||
@@ -142,8 +132,14 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
|
||||
def run(self): # pragma: no cover
|
||||
"""Task wrapper to run policy checking"""
|
||||
result = None
|
||||
try:
|
||||
self.connection.send(self.profiling_wrapper())
|
||||
start = perf_counter()
|
||||
result = self.profiling_wrapper()
|
||||
end = perf_counter()
|
||||
result._exec_time = max((end - start), 0)
|
||||
except Exception as exc: # noqa
|
||||
LOGGER.warning("Policy failed to run", exc=exc)
|
||||
self.connection.send(PolicyResult(False, str(exc)))
|
||||
result = PolicyResult(False, str(exc))
|
||||
finally:
|
||||
self.connection.send(result)
|
||||
|
||||
@@ -33,6 +33,9 @@ class TestPolicyEngine(TestCase):
|
||||
self.policy_raises = ExpressionPolicy.objects.create(
|
||||
name=generate_id(), expression="{{ 0/0 }}"
|
||||
)
|
||||
self.group_member = Group.objects.create(name=generate_id())
|
||||
self.user.groups.add(self.group_member)
|
||||
self.group_non_member = Group.objects.create(name=generate_id())
|
||||
|
||||
def test_engine_empty(self):
|
||||
"""Ensure empty policy list passes"""
|
||||
@@ -51,7 +54,7 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ("dummy",))
|
||||
|
||||
def test_engine_mode_all(self):
|
||||
def test_engine_mode_all_dyn(self):
|
||||
"""Ensure all policies passes with AND mode (false and true -> false)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -67,7 +70,7 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_any(self):
|
||||
def test_engine_mode_any_dyn(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -83,6 +86,26 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_all_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_mode_any_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_negate(self):
|
||||
"""Test negate flag"""
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
|
||||
@@ -77,6 +77,8 @@ class PolicyResult:
|
||||
|
||||
log_messages: list[LogEvent] | None
|
||||
|
||||
_exec_time: int | None
|
||||
|
||||
def __init__(self, passing: bool, *messages: str):
|
||||
self.passing = passing
|
||||
self.messages = messages
|
||||
@@ -84,6 +86,7 @@ class PolicyResult:
|
||||
self.source_binding = None
|
||||
self.source_results = []
|
||||
self.log_messages = []
|
||||
self._exec_time = None
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@@ -68,6 +68,8 @@ class IDToken:
|
||||
at_hash: str | None = None
|
||||
# Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
|
||||
sid: str | None = None
|
||||
# JWT ID, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.7
|
||||
jti: str | None = None
|
||||
|
||||
claims: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -81,6 +83,7 @@ class IDToken:
|
||||
(token.expires if token.expires is not None else default_token_duration()).timestamp()
|
||||
)
|
||||
id_token.iss = provider.get_issuer(request)
|
||||
id_token.jti = generate_id()
|
||||
id_token.aud = provider.client_id
|
||||
id_token.claims = {}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils import translation
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@@ -140,26 +141,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "+",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
provider.refresh_from_db()
|
||||
self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")])
|
||||
|
||||
def test_invalid_redirect_uri_regex(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
@@ -393,7 +374,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
"nonce": generate_id(),
|
||||
},
|
||||
)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
|
||||
self.assertEqual(
|
||||
response.url,
|
||||
@@ -465,7 +446,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
|
||||
jwt = self.validate_jwe(token, provider)
|
||||
self.assertEqual(jwt["amr"], ["pwd"])
|
||||
@@ -564,7 +545,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
@@ -690,18 +671,21 @@ class TestAuthorize(OAuthTestCase):
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
self.client.logout()
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
try:
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
finally:
|
||||
translation.deactivate()
|
||||
|
||||
@apply_blueprint("default/flow-default-authentication-flow.yaml")
|
||||
def test_ui_locales_invalid(self):
|
||||
|
||||
@@ -4,22 +4,19 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import jwt
|
||||
from django.test import RequestFactory
|
||||
from django.utils import timezone
|
||||
from dramatiq.results.errors import ResultFailure
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError, Timeout
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession, Session
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import hash_session_key
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
OAuth2LogoutMethod,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
RefreshToken,
|
||||
)
|
||||
from authentik.providers.oauth2.tasks import send_backchannel_logout_request
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
@@ -45,52 +42,6 @@ class TestBackChannelLogout(OAuthTestCase):
|
||||
self.app.provider = self.provider
|
||||
self.app.save()
|
||||
|
||||
def _create_session(self, session_key=None):
|
||||
"""Create a session with the given key or a generated one"""
|
||||
session_key = session_key or f"session-{generate_id()}"
|
||||
session = Session.objects.create(
|
||||
session_key=session_key,
|
||||
expires=timezone.now() + timezone.timedelta(hours=1),
|
||||
last_ip="255.255.255.255",
|
||||
)
|
||||
auth_session = AuthenticatedSession.objects.create(
|
||||
session=session,
|
||||
user=self.user,
|
||||
)
|
||||
return auth_session
|
||||
|
||||
def _create_token(
|
||||
self, provider, user, session=None, token_type="access", token_id=None
|
||||
): # nosec
|
||||
"""Create a token of the specified type"""
|
||||
token_id = token_id or f"{token_type}-token-{generate_id()}"
|
||||
kwargs = {
|
||||
"provider": provider,
|
||||
"user": user,
|
||||
"session": session,
|
||||
"token": token_id,
|
||||
"_id_token": "{}",
|
||||
"auth_time": timezone.now(),
|
||||
}
|
||||
|
||||
if token_type == "access": # nosec
|
||||
return AccessToken.objects.create(**kwargs)
|
||||
else: # refresh
|
||||
return RefreshToken.objects.create(**kwargs)
|
||||
|
||||
def _create_provider(self, name=None):
|
||||
"""Create an OAuth2 provider"""
|
||||
name = name or f"provider-{generate_id()}"
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=name,
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[
|
||||
RedirectURI(RedirectURIMatchingMode.STRICT, f"http://{name}/callback"),
|
||||
],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
return provider
|
||||
|
||||
def _create_logout_token(
|
||||
self,
|
||||
provider: OAuth2Provider | None = None,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""Device backchannel tests"""
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
|
||||
|
||||
@@ -26,7 +29,7 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_backchannel_invalid(self):
|
||||
def test_backchannel_invalid_client_id_via_post_body(self):
|
||||
"""Test backchannel"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
@@ -50,7 +53,7 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel(self):
|
||||
def test_backchannel_client_id_via_post_body(self):
|
||||
"""Test backchannel"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
@@ -61,3 +64,104 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
def test_backchannel_invalid_client_id_via_auth_header(self):
|
||||
"""Test backchannel"""
|
||||
creds = b64encode(b"foo:").decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
# test without application
|
||||
self.application.provider = None
|
||||
self.application.save()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
data={
|
||||
"client_id": "test",
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel_client_id_via_auth_header(self):
|
||||
"""Test backchannel"""
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
def test_backchannel_client_id_via_auth_header_urlencoded(self):
|
||||
"""Test URL-encoded client IDs in Basic auth"""
|
||||
self.provider.client_id = "test/client+id"
|
||||
self.provider.save()
|
||||
creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_backchannel_scopes(self):
|
||||
"""Test backchannel"""
|
||||
self.provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-email",
|
||||
"goauthentik.io/providers/oauth2/scope-profile",
|
||||
]
|
||||
)
|
||||
)
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
data={"scope": "openid email"},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(len(token.scope), 2)
|
||||
self.assertIn("openid", token.scope)
|
||||
self.assertIn("email", token.scope)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_backchannel_scopes_extra(self):
|
||||
"""Test backchannel"""
|
||||
self.provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-email",
|
||||
"goauthentik.io/providers/oauth2/scope-profile",
|
||||
]
|
||||
)
|
||||
)
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
data={"scope": "openid email foo"},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(len(token.scope), 2)
|
||||
self.assertIn("openid", token.scope)
|
||||
self.assertIn("email", token.scope)
|
||||
|
||||
@@ -14,6 +14,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
ClientTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -43,7 +44,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
|
||||
def test_introspect_refresh(self):
|
||||
"""Test introspect"""
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -75,7 +76,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
|
||||
def test_introspect_access(self):
|
||||
"""Test introspect"""
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -130,7 +131,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
)
|
||||
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -169,3 +170,76 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_provider_public(self):
|
||||
"""Test introspect"""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-introspection"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content.decode(),
|
||||
{
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_provider_fed(self):
|
||||
"""Test introspect with federation. self.provider is a confidential
|
||||
client and other_provider is a public client."""
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-introspection"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content.decode(),
|
||||
{
|
||||
"acr": ACR_AUTHENTIK_DEFAULT,
|
||||
"sub": "bar",
|
||||
"iss": "foo",
|
||||
"active": True,
|
||||
"client_id": other_provider.client_id,
|
||||
"scope": " ".join(token.scope),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
|
||||
def test_revoke_refresh(self):
|
||||
"""Test revoke"""
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -69,7 +69,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
|
||||
def test_revoke_access(self):
|
||||
"""Test revoke"""
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -105,7 +105,19 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
"""Test revoke (invalid auth)"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION="Basic fqewr",
|
||||
HTTP_AUTHORIZATION="Basic aaa",
|
||||
data={
|
||||
"token": generate_id(),
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
def test_revoke_invalid_auth_secret(self):
|
||||
"""Test revoke (invalid secret)"""
|
||||
invalid_auth = b64encode(f"{self.provider.client_id}:aaa".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {invalid_auth}",
|
||||
data={
|
||||
"token": generate_id(),
|
||||
},
|
||||
@@ -116,7 +128,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
"""Test revoke public client"""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -220,3 +232,74 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
self.assertEqual(AccessToken.objects.all().count(), 0)
|
||||
self.assertEqual(RefreshToken.objects.all().count(), 0)
|
||||
self.assertEqual(DeviceToken.objects.all().count(), 0)
|
||||
|
||||
def test_revoke_provider_fed(self):
|
||||
"""Test revoke with federation. self.provider is a confidential
|
||||
client and other_provider is a public client."""
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(res.content.decode(), {})
|
||||
|
||||
def test_revoke_provider_fed_public(self):
|
||||
"""Test revoke with federation. self.provider is a public
|
||||
client and other_provider is a public client."""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
auth_public = b64encode(f"{self.provider.client_id}:{generate_id()}".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {auth_public}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertTrue(AccessToken.objects.filter(token=token.token).exists())
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Test token view"""
|
||||
|
||||
from base64 import b64encode
|
||||
from datetime import timedelta
|
||||
from json import dumps
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
from django.utils.timezone import now
|
||||
from freezegun import freeze_time
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.common.oauth.constants import (
|
||||
@@ -28,6 +32,7 @@ from authentik.providers.oauth2.models import (
|
||||
ScopeMapping,
|
||||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.utils import extract_client_auth
|
||||
from authentik.providers.oauth2.views.token import TokenParams
|
||||
|
||||
|
||||
@@ -97,7 +102,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -115,6 +120,20 @@ class TestToken(OAuthTestCase):
|
||||
params = TokenParams.parse(request, provider, provider.client_id, provider.client_secret)
|
||||
self.assertEqual(params.provider, provider)
|
||||
|
||||
def test_extract_client_auth_basic_auth_percent_decodes(self):
|
||||
"""test percent-decoding of client credentials in Basic auth"""
|
||||
header = b64encode(
|
||||
f"{quote('client/id', safe='')}:{quote('secret+/==', safe='')}".encode()
|
||||
).decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client/id", "secret+/=="))
|
||||
|
||||
def test_extract_client_auth_basic_auth_preserves_raw_plus(self):
|
||||
"""test compatibility with clients that still send raw plus characters"""
|
||||
header = b64encode(b"client:secret+plus").decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client", "secret+plus"))
|
||||
|
||||
def test_auth_code_view(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
@@ -140,7 +159,7 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -182,7 +201,7 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.validate_jwe(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
@@ -209,7 +228,7 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -229,10 +248,8 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
|
||||
self.assertEqual(response["Access-Control-Allow-Origin"], "http://local.invalid")
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh: RefreshToken = RefreshToken.objects.filter(
|
||||
user=user, provider=provider, revoked=False
|
||||
).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh = RefreshToken.objects.filter(user=user, provider=provider, revoked=False).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -269,7 +286,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -287,10 +304,8 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://another.invalid",
|
||||
)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh: RefreshToken = RefreshToken.objects.filter(
|
||||
user=user, provider=provider, revoked=False
|
||||
).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh = RefreshToken.objects.filter(user=user, provider=provider, revoked=False).first()
|
||||
self.assertNotIn("Access-Control-Allow-Credentials", response)
|
||||
self.assertNotIn("Access-Control-Allow-Origin", response)
|
||||
self.assertJSONEqual(
|
||||
@@ -331,7 +346,7 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -349,9 +364,7 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
new_token: RefreshToken = (
|
||||
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||
)
|
||||
new_token = RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||
# Post again with initial token -> get new refresh token
|
||||
# and revoke old one
|
||||
response = self.client.post(
|
||||
@@ -379,7 +392,11 @@ class TestToken(OAuthTestCase):
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_refresh_token_view_threshold(self):
|
||||
"""test request param"""
|
||||
"""refresh token threshold
|
||||
|
||||
threshold set to 1 hour, refresh token expires in 2 hours.
|
||||
First request should not return a new refresh token, second request
|
||||
has a fake time 1 hours in the future which should return a new access token"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
@@ -402,13 +419,14 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
_id_token=dumps({}),
|
||||
auth_time=timezone.now(),
|
||||
_scope="offline_access",
|
||||
expires=now() + timedelta(hours=2),
|
||||
)
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
@@ -420,9 +438,7 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://local.invalid",
|
||||
)
|
||||
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
|
||||
self.assertEqual(response["Access-Control-Allow-Origin"], "http://local.invalid")
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -437,6 +453,42 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
with freeze_time(now() + timedelta(hours=1, minutes=10)):
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
data={
|
||||
"grant_type": GRANT_TYPE_REFRESH_TOKEN,
|
||||
"refresh_token": token.token,
|
||||
"redirect_uri": "http://local.invalid",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://local.invalid",
|
||||
)
|
||||
access = (
|
||||
AccessToken.objects.filter(user=user, provider=provider)
|
||||
.exclude(pk=access.pk)
|
||||
.first()
|
||||
)
|
||||
refresh = (
|
||||
RefreshToken.objects.filter(user=user, provider=provider)
|
||||
.exclude(pk=token.pk)
|
||||
.first()
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"access_token": access.token,
|
||||
"token_type": TOKEN_TYPE,
|
||||
"expires_in": 3600,
|
||||
"id_token": provider.encode(
|
||||
access.id_token.to_dict(),
|
||||
),
|
||||
"scope": "offline_access",
|
||||
"refresh_token": refresh.token,
|
||||
},
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_scope_claim_override_via_property_mapping(self):
|
||||
"""Test that property mappings can override the scope claim in access tokens.
|
||||
@@ -484,7 +536,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
jwt_data = self.validate_jwt(access, provider)
|
||||
|
||||
# The scope should be the custom value from the property mapping,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
@@ -178,6 +179,41 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_basic_auth_urlencoded_client_secret(self):
|
||||
"""test successful with URL-encoded Basic auth credentials"""
|
||||
client_secret = b64encode(f"sa:{self.token.key}".encode()).decode()
|
||||
header = b64encode(
|
||||
f"{quote(self.provider.client_id, safe='')}:{quote(client_secret, safe='')}".encode()
|
||||
).decode()
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
{
|
||||
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["token_type"], TOKEN_TYPE)
|
||||
_, alg = self.provider.jwt_key
|
||||
jwt = decode(
|
||||
body["access_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
jwt = decode(
|
||||
body["id_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_password(self):
|
||||
"""test successful (password grant)"""
|
||||
response = self.client.post(
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestUserinfo(OAuthTestCase):
|
||||
self.app.provider = self.provider
|
||||
self.app.save()
|
||||
self.user = create_test_admin_user()
|
||||
self.token: AccessToken = AccessToken.objects.create(
|
||||
self.token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
|
||||
@@ -7,7 +7,7 @@ from binascii import Error
|
||||
from hashlib import sha256
|
||||
from hmac import compare_digest
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.http.response import HttpResponseRedirect
|
||||
@@ -122,6 +122,10 @@ def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
|
||||
try:
|
||||
user_pass = b64decode(b64_user_pass).decode("utf-8").partition(":")
|
||||
client_id, _, client_secret = user_pass
|
||||
# RFC 6749 requires client credentials in Basic auth to be form-encoded first.
|
||||
# We only percent-decode here so raw `+` characters keep their previous meaning.
|
||||
client_id = unquote(client_id)
|
||||
client_secret = unquote(client_secret)
|
||||
except ValueError, Error:
|
||||
client_id = client_secret = "" # nosec
|
||||
else:
|
||||
|
||||
@@ -58,7 +58,6 @@ from authentik.providers.oauth2.models import (
|
||||
AuthorizationCode,
|
||||
GrantTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
ResponseMode,
|
||||
ResponseTypes,
|
||||
@@ -196,14 +195,6 @@ class OAuthAuthorizationParams:
|
||||
LOGGER.warning("Missing redirect uri.")
|
||||
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||
|
||||
if len(allowed_redirect_urls) < 1:
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
self.provider.redirect_uris = [
|
||||
RedirectURI(RedirectURIMatchingMode.STRICT, self.redirect_uri)
|
||||
]
|
||||
self.provider.save()
|
||||
allowed_redirect_urls = self.provider.redirect_uris
|
||||
|
||||
match_found = False
|
||||
for allowed in allowed_redirect_urls:
|
||||
if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
|
||||
|
||||
@@ -15,8 +15,8 @@ from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.utils import TokenResponse
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -28,11 +28,11 @@ class DeviceView(View):
|
||||
|
||||
client_id: str
|
||||
provider: OAuth2Provider
|
||||
scopes: list[str] = []
|
||||
scopes: set[str] = []
|
||||
|
||||
def parse_request(self):
|
||||
"""Parse incoming request"""
|
||||
client_id = self.request.POST.get("client_id", None)
|
||||
client_id, _ = extract_client_auth(self.request)
|
||||
if not client_id:
|
||||
raise DeviceCodeError("invalid_client")
|
||||
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
|
||||
@@ -44,7 +44,21 @@ class DeviceView(View):
|
||||
raise DeviceCodeError("invalid_client") from None
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
self.scopes = self.request.POST.get("scope", "").split(" ")
|
||||
|
||||
scopes_to_check = set(self.request.POST.get("scope", "").split())
|
||||
default_scope_names = set(
|
||||
ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
|
||||
"scope_name", flat=True
|
||||
)
|
||||
)
|
||||
self.scopes = scopes_to_check
|
||||
if not scopes_to_check.issubset(default_scope_names):
|
||||
LOGGER.info(
|
||||
"Application requested scopes not configured, setting to overlap",
|
||||
scope_allowed=default_scope_names,
|
||||
scope_given=self.scopes,
|
||||
)
|
||||
self.scopes = self.scopes.intersection(default_scope_names)
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
throttle = AnonRateThrottle()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from django.db.models import Q
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views import View
|
||||
@@ -10,7 +11,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.providers.oauth2.errors import TokenIntrospectionError
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.utils import TokenResponse, authenticate_provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -33,10 +34,7 @@ class TokenIntrospectionParams:
|
||||
self.id_token = self.token.id_token
|
||||
|
||||
if not self.token.id_token:
|
||||
LOGGER.debug(
|
||||
"token not an authentication token",
|
||||
token=self.token,
|
||||
)
|
||||
LOGGER.debug("token not an authentication token", token=self.token)
|
||||
raise TokenIntrospectionError()
|
||||
|
||||
@staticmethod
|
||||
@@ -45,14 +43,23 @@ class TokenIntrospectionParams:
|
||||
raw_token = request.POST.get("token")
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
LOGGER.info("Failed to authenticate introspection request")
|
||||
raise TokenIntrospectionError
|
||||
if provider.client_type != ClientTypes.CONFIDENTIAL:
|
||||
LOGGER.info("Introspection request from public provider, denying.")
|
||||
raise TokenIntrospectionError
|
||||
|
||||
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
query = Q(
|
||||
Q(provider=provider) | Q(provider__jwt_federation_providers__in=[provider]),
|
||||
token=raw_token,
|
||||
)
|
||||
|
||||
access_token = AccessToken.objects.filter(query).first()
|
||||
if access_token:
|
||||
return TokenIntrospectionParams(access_token, provider)
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
return TokenIntrospectionParams(access_token, access_token.provider)
|
||||
refresh_token = RefreshToken.objects.filter(query).first()
|
||||
if refresh_token:
|
||||
return TokenIntrospectionParams(refresh_token, provider)
|
||||
return TokenIntrospectionParams(refresh_token, refresh_token.provider)
|
||||
LOGGER.debug("Token does not exist", token=raw_token)
|
||||
raise TokenIntrospectionError()
|
||||
|
||||
|
||||
@@ -704,7 +704,7 @@ class TokenView(View):
|
||||
refresh_token_threshold = timedelta_from_string(self.provider.refresh_token_threshold)
|
||||
if (
|
||||
refresh_token_threshold.total_seconds() == 0
|
||||
or (now - self.params.refresh_token.expires) > refresh_token_threshold
|
||||
or (self.params.refresh_token.expires - now) < refresh_token_threshold
|
||||
):
|
||||
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
|
||||
refresh_token = RefreshToken(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from django.db.models import Q
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views import View
|
||||
@@ -32,15 +33,25 @@ class TokenRevocationParams:
|
||||
raw_token = request.POST.get("token")
|
||||
|
||||
provider, _, _ = provider_from_request(request)
|
||||
if provider and provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
# By default clients can only revoke their own tokens
|
||||
query = Q(provider=provider, token=raw_token)
|
||||
if provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
# If the request is authenticated by a confidential provider, it can also
|
||||
# revoke federated tokens
|
||||
query = Q(
|
||||
Q(provider=provider) | Q(provider__jwt_federation_providers__in=[provider]),
|
||||
token=raw_token,
|
||||
)
|
||||
|
||||
access_token = AccessToken.objects.filter(token=raw_token).first()
|
||||
access_token = AccessToken.objects.filter(query).first()
|
||||
if access_token:
|
||||
return TokenRevocationParams(access_token, provider)
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token).first()
|
||||
refresh_token = RefreshToken.objects.filter(query).first()
|
||||
if refresh_token:
|
||||
return TokenRevocationParams(refresh_token, provider)
|
||||
LOGGER.debug("Token does not exist", token=raw_token)
|
||||
|
||||
@@ -27,6 +27,8 @@ class TraefikMiddlewareSpecForwardAuth:
|
||||
|
||||
trustForwardHeader: bool = field(default=True)
|
||||
|
||||
maxResponseBodySize: int = field(default=1024 * 1024 * 4)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TraefikMiddlewareSpec:
|
||||
@@ -140,6 +142,7 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
|
||||
],
|
||||
authResponseHeadersRegex="",
|
||||
trustForwardHeader=True,
|
||||
maxResponseBodySize=1024 * 1024 * 4,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Proxy provider signals"""
|
||||
|
||||
from django.db.models.signals import pre_delete
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.core.models import AuthenticatedSession
|
||||
from authentik.providers.proxy.tasks import proxy_on_logout
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
proxy_on_logout.send(instance.session.session_key)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""proxy provider tasks"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.outposts.consumer import build_outpost_group
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.oauth2.id_token import hash_session_key
|
||||
|
||||
|
||||
@actor(description=_("Terminate session on Proxy outpost."))
|
||||
def proxy_on_logout(session_id: str):
|
||||
layer = get_channel_layer()
|
||||
hashed_session_id = hash_session_key(session_id)
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
group,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": hashed_session_id,
|
||||
},
|
||||
)
|
||||
@@ -231,8 +231,8 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
class SAMLMetadataSerializer(PassiveSerializer):
|
||||
"""SAML Provider Metadata serializer"""
|
||||
|
||||
metadata = CharField(read_only=True)
|
||||
download_url = CharField(read_only=True, required=False)
|
||||
metadata = CharField()
|
||||
download_url = CharField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class SAMLProviderImportSerializer(PassiveSerializer):
|
||||
@@ -314,7 +314,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
return response
|
||||
return Response({"metadata": metadata}, content_type="application/json")
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return Response({"metadata": ""}, content_type="application/json")
|
||||
raise Http404 from None
|
||||
|
||||
@permission_required(
|
||||
None,
|
||||
|
||||
@@ -157,7 +157,7 @@ class TestSAMLProviderAPI(APITestCase):
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-metadata", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(404, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-metadata", kwargs={"pk": "abc"}),
|
||||
)
|
||||
|
||||
@@ -8,11 +8,11 @@ from django.urls import reverse
|
||||
|
||||
from authentik.common.saml.constants import SAML_NAME_ID_FORMAT_EMAIL
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_brand, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_brand, create_test_cert, create_test_flow
|
||||
from authentik.flows.planner import FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.providers.saml.exceptions import CannotHandleAssertion
|
||||
from authentik.providers.saml.models import SAMLProvider
|
||||
from authentik.providers.saml.models import SAMLBindings, SAMLLogoutMethods, SAMLProvider
|
||||
from authentik.providers.saml.processors.logout_request import LogoutRequestProcessor
|
||||
from authentik.providers.saml.views.flows import PLAN_CONTEXT_SAML_RELAY_STATE
|
||||
from authentik.providers.saml.views.sp_slo import (
|
||||
@@ -91,32 +91,33 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_redirect_view_handles_logout_response_with_relay_state(self):
|
||||
"""Test that redirect view handles logout response with RelayState"""
|
||||
# Use raw URL (no encoding needed)
|
||||
relay_state = "https://idp.example.com/flow/return"
|
||||
def test_redirect_view_handles_logout_response_with_plan_context(self):
|
||||
"""Test that redirect view always redirects to plan context URL, ignoring RelayState"""
|
||||
plan_relay_state = "https://idp.example.com/flow/return"
|
||||
|
||||
# Create request with SAML logout response
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": relay_state,
|
||||
"RelayState": "https://somewhere-else.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to relay state URL
|
||||
# Should redirect to plan context URL, not the request's RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_redirect_view_handles_logout_response_plain_relay_state(self):
|
||||
"""Test that redirect view handles logout response with plain RelayState"""
|
||||
def test_redirect_view_ignores_relay_state_without_plan(self):
|
||||
"""Test that redirect view ignores RelayState and falls back to root when no plan context"""
|
||||
relay_state = "https://sp.example.com/plain"
|
||||
|
||||
# Create request with SAML logout response
|
||||
@@ -134,9 +135,9 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to plain relay state
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_handles_logout_response_no_relay_state_with_plan_context(self):
|
||||
"""Test that redirect view uses plan context fallback when no RelayState"""
|
||||
@@ -228,29 +229,30 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_post_view_handles_logout_response_with_relay_state(self):
|
||||
"""Test that POST view handles logout response with RelayState"""
|
||||
# Use raw URL (no encoding needed)
|
||||
relay_state = "https://idp.example.com/flow/return"
|
||||
def test_post_view_handles_logout_response_with_plan_context(self):
|
||||
"""Test that POST view always redirects to plan context URL, ignoring RelayState"""
|
||||
plan_relay_state = "https://idp.example.com/flow/return"
|
||||
|
||||
# Create POST request with SAML logout response
|
||||
request = self.factory.post(
|
||||
f"/slo/post/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": relay_state,
|
||||
"RelayState": "https://somewhere-else.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingPOSTView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to relay state URL
|
||||
# Should redirect to plan context URL, not the request's RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_post_view_handles_logout_response_no_relay_state_with_plan_context(self):
|
||||
"""Test that POST view uses plan context fallback when no RelayState"""
|
||||
@@ -417,7 +419,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.resolve_provider_application()
|
||||
|
||||
def test_relay_state_decoding_failure(self):
|
||||
"""Test handling of RelayState that's a path"""
|
||||
"""Test that arbitrary path RelayState is ignored and redirects to root"""
|
||||
# Create request with relay state that is a path
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
@@ -433,6 +435,357 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should treat it as plain URL and redirect to it
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, "/some/invalid/path")
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_blocks_external_relay_state(self):
|
||||
"""Test that redirect view ignores external malicious URL and redirects to root"""
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_ignores_relay_state_uses_plan_context(self):
|
||||
"""Test that redirect view always uses plan context URL regardless of RelayState"""
|
||||
plan_relay_state = "https://authentik.example.com/if/flow/logout/"
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should always use plan context value, ignoring malicious RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_post_view_ignores_external_relay_state(self):
|
||||
"""Test that POST view ignores external RelayState and redirects to root"""
|
||||
request = self.factory.post(
|
||||
f"/slo/post/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingPOSTView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
|
||||
class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
"""Test SP-initiated SAML SLO logout method branching"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.factory = RequestFactory()
|
||||
self.brand = create_test_brand()
|
||||
self.flow = create_test_flow()
|
||||
self.invalidation_flow = create_test_flow()
|
||||
self.cert = create_test_cert()
|
||||
|
||||
# Create provider with sls_url
|
||||
self.provider = SAMLProvider.objects.create(
|
||||
name="test-provider",
|
||||
authorization_flow=self.flow,
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
# Create application
|
||||
self.application = Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app-logout-methods",
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
# Create logout request processor for generating test requests
|
||||
self.processor = LogoutRequestProcessor(
|
||||
provider=self.provider,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
relay_state="https://sp.example.com/return",
|
||||
)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_native_post_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_NATIVE with POST binding parses request correctly"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_NATIVE
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed and provider is configured correctly
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
self.assertEqual(view.provider.logout_method, SAMLLogoutMethods.FRONTCHANNEL_NATIVE)
|
||||
self.assertEqual(view.provider.sls_binding, SAMLBindings.POST)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_native_redirect_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_NATIVE with REDIRECT binding creates redirect URL"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_NATIVE
|
||||
self.provider.sls_binding = SAMLBindings.REDIRECT
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_iframe_post_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_IFRAME with POST binding creates IframeLogoutStageView"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_iframe_redirect_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_IFRAME with REDIRECT binding"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.sls_binding = SAMLBindings.REDIRECT
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_backchannel_parses_request(self, mock_auth_session):
|
||||
"""Test BACKCHANNEL mode parses request correctly"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.BACKCHANNEL
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed and provider is configured correctly
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
self.assertEqual(view.provider.logout_method, SAMLLogoutMethods.BACKCHANNEL)
|
||||
self.assertEqual(view.provider.sls_binding, SAMLBindings.POST)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_no_sls_url_only_session_end(self, mock_auth_session):
|
||||
"""Test that only SessionEndStage is appended when sls_url is empty"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
# Create provider without sls_url
|
||||
provider_no_sls = SAMLProvider.objects.create(
|
||||
name="no-sls-provider",
|
||||
authorization_flow=self.flow,
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="", # No SLS URL
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
app_no_sls = Application.objects.create(
|
||||
name="no-sls-app",
|
||||
slug="no-sls-app",
|
||||
provider=provider_no_sls,
|
||||
)
|
||||
|
||||
processor = LogoutRequestProcessor(
|
||||
provider=provider_no_sls,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
)
|
||||
encoded_request = processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{app_no_sls.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=app_no_sls.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the provider has no sls_url
|
||||
self.assertEqual(view.provider.sls_url, "")
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_relay_state_propagation(self, mock_auth_session):
|
||||
"""Test that relay state from logout request is passed through to response"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.save()
|
||||
|
||||
expected_relay_state = "https://sp.example.com/custom-return"
|
||||
|
||||
processor = LogoutRequestProcessor(
|
||||
provider=self.provider,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
relay_state=expected_relay_state,
|
||||
)
|
||||
encoded_request = processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": expected_relay_state,
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify relay state was captured
|
||||
logout_request = view.plan_context.get("authentik/providers/saml/logout_request")
|
||||
self.assertEqual(logout_request.relay_state, expected_relay_state)
|
||||
|
||||
@@ -29,6 +29,24 @@ from authentik.providers.saml.views.flows import (
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def _get_redirect_url(request: HttpRequest, relay_state: str = "") -> str:
|
||||
"""Get the safe redirect URL from the plan context, logging a warning if the
|
||||
incoming relay_state doesn't match the stored value."""
|
||||
stored_relay_state = ""
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
stored_relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE, "")
|
||||
|
||||
if relay_state and relay_state != stored_relay_state:
|
||||
LOGGER.warning(
|
||||
"SAML logout relay_state mismatch, possible open redirect attempt",
|
||||
received_relay_state=relay_state,
|
||||
stored_relay_state=stored_relay_state,
|
||||
)
|
||||
|
||||
return stored_relay_state
|
||||
|
||||
|
||||
class SPInitiatedSLOView(PolicyAccessView):
|
||||
"""Handle SP-initiated SAML Single Logout requests"""
|
||||
|
||||
@@ -96,17 +114,9 @@ class SPInitiatedSLOBindingRedirectView(SPInitiatedSLOView):
|
||||
# IDP SLO, so we want to redirect to our next provider
|
||||
if REQUEST_KEY_SAML_RESPONSE in request.GET:
|
||||
relay_state = request.GET.get(REQUEST_KEY_RELAY_STATE, "")
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No RelayState provided, try to get return URL from plan context
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE)
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No relay state and no plan context - redirect to root
|
||||
redirect_url = _get_redirect_url(request, relay_state)
|
||||
if redirect_url:
|
||||
return redirect(redirect_url)
|
||||
return redirect("authentik_core:root-redirect")
|
||||
|
||||
# For SAML logout requests, use the parent dispatch with auth checks
|
||||
@@ -147,17 +157,9 @@ class SPInitiatedSLOBindingPOSTView(SPInitiatedSLOView):
|
||||
# IDP SLO, so we want to redirect to our next provider
|
||||
if REQUEST_KEY_SAML_RESPONSE in request.POST:
|
||||
relay_state = request.POST.get(REQUEST_KEY_RELAY_STATE, "")
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No RelayState provided, try to get return URL from plan context
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE)
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No relay state and no plan context - redirect to root
|
||||
redirect_url = _get_redirect_url(request, relay_state)
|
||||
if redirect_url:
|
||||
return redirect(redirect_url)
|
||||
return redirect("authentik_core:root-redirect")
|
||||
|
||||
# For SAML logout requests, use the parent dispatch with auth checks
|
||||
|
||||
@@ -10,6 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderGroup
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
|
||||
|
||||
class SCIMGroupTests(TestCase):
|
||||
@@ -205,3 +206,80 @@ class SCIMGroupTests(TestCase):
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
self.assertNotIn("PUT", [req.method for req in mock.request_history])
|
||||
|
||||
def _create_stale_provider_group(self, scim_id: str) -> Group:
|
||||
"""Create a group that is outside the provider's scope (via group_filters) with an
|
||||
existing SCIMProviderGroup, simulating a previously synced group now out of scope."""
|
||||
self.app.backchannel_providers.remove(self.provider)
|
||||
anchor = Group.objects.create(name=generate_id())
|
||||
stale = Group.objects.create(name=generate_id())
|
||||
self.app.backchannel_providers.add(self.provider)
|
||||
|
||||
self.provider.group_filters.set([anchor])
|
||||
SCIMProviderGroup.objects.create(provider=self.provider, group=stale, scim_id=scim_id)
|
||||
return stale
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) groups are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Groups/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_not_found(self, mock: Mocker):
|
||||
"""Stale group cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_transient_error(self, mock: Mocker):
|
||||
"""Stale group cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_dry_run(self, mock: Mocker):
|
||||
"""Stale group cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""SCIM User tests"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from jsonschema import validate
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.core.models import Application, Group, User, UserTypes
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.sync.outgoing.base import SAFE_METHODS
|
||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderUser
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects, sync_tasks
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@@ -537,3 +539,104 @@ class SCIMUserTests(TestCase):
|
||||
self.assertEqual(mock.call_count, 2)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
|
||||
def _create_stale_provider_user(self, scim_id: str, uid: str) -> User:
|
||||
"""Create a service-account user (excluded from provider scope) with an existing
|
||||
SCIMProviderUser, simulating a previously synced user that is now out of scope."""
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
type=UserTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
SCIMProviderUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
|
||||
return user
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) users are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Users/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_not_found(self, mock: Mocker):
|
||||
"""Stale user cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_transient_error(self, mock: Mocker):
|
||||
"""Stale user cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_dry_run(self, mock: Mocker):
|
||||
"""Stale user cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
def test_sync_cleanup_client_for_model_transient(self):
|
||||
"""Cleanup silently skips an object type when client_for_model raises
|
||||
TransientSyncException"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
SCIMProvider,
|
||||
"client_for_model",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
def test_sync_transient_exception(self):
|
||||
"""TransientSyncException in _sync_cleanup is caught by sync() which then
|
||||
schedules a retry"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
sync_tasks,
|
||||
"_sync_cleanup",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
@@ -60,11 +60,7 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
sources = sources.exclude(pk=self.instance.pk)
|
||||
if sources.exists():
|
||||
raise ValidationError(
|
||||
{
|
||||
"sync_users_password": _(
|
||||
"Only a single LDAP Source with password synchronization is allowed"
|
||||
)
|
||||
}
|
||||
_("Only a single LDAP Source with password synchronization is allowed")
|
||||
)
|
||||
return sync_users_password
|
||||
|
||||
@@ -221,7 +217,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
for sync_class in SYNC_CLASSES:
|
||||
class_name = sync_class.name()
|
||||
all_objects.setdefault(class_name, [])
|
||||
for page in sync_class(source).get_objects(size_limit=10):
|
||||
for page in sync_class(source, Task()).get_objects(size_limit=10):
|
||||
for obj in page:
|
||||
obj: dict
|
||||
obj.pop("raw_attributes", None)
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
from ldap3 import ALL, NONE, RANDOM, Connection, Server, ServerPool, Tls
|
||||
from ldap3.core.exceptions import LDAPException, LDAPInsufficientAccessRightsResult, LDAPSchemaError
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import (
|
||||
Group,
|
||||
@@ -31,6 +32,7 @@ from authentik.tasks.schedules.common import ScheduleSpec
|
||||
LDAP_TIMEOUT = 15
|
||||
LDAP_UNIQUENESS = "ldap_uniq"
|
||||
LDAP_DISTINGUISHED_NAME = "distinguishedName"
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def flatten(value: Any) -> Any:
|
||||
@@ -268,6 +270,7 @@ class LDAPSource(IncomingSyncSource):
|
||||
)
|
||||
|
||||
if self.start_tls:
|
||||
LOGGER.debug("Connection StartTLS", source=self)
|
||||
conn.start_tls(read_server_info=False)
|
||||
try:
|
||||
successful = conn.bind()
|
||||
@@ -278,7 +281,9 @@ class LDAPSource(IncomingSyncSource):
|
||||
# See https://github.com/goauthentik/authentik/issues/4590
|
||||
# See also https://github.com/goauthentik/authentik/issues/3399
|
||||
if server_kwargs.get("get_info", ALL) == NONE:
|
||||
LOGGER.warning("Failed to connect after schema downgrade", source=self, exc=exc)
|
||||
raise exc
|
||||
LOGGER.warning("Downgrading connection to no schema info", source=self, exc=exc)
|
||||
server_kwargs["get_info"] = NONE
|
||||
return self.connection(server, server_kwargs, connection_kwargs)
|
||||
finally:
|
||||
|
||||
@@ -27,7 +27,7 @@ from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
SYNC_CLASSES = [
|
||||
SYNC_CLASSES: list[type[BaseLDAPSynchronizer]] = [
|
||||
UserLDAPSynchronizer,
|
||||
GroupLDAPSynchronizer,
|
||||
MembershipLDAPSynchronizer,
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
"""LDAP Source API tests"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.db.models import Q
|
||||
from django.urls import reverse
|
||||
from rest_framework.exceptions import ErrorDetail
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.lib.generators import generate_key
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.sources.ldap.api import LDAPSourceSerializer
|
||||
from authentik.sources.ldap.models import LDAPSource
|
||||
from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
|
||||
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
|
||||
|
||||
LDAP_PASSWORD = generate_key()
|
||||
|
||||
@@ -26,12 +35,13 @@ class LDAPAPITests(APITestCase):
|
||||
}
|
||||
)
|
||||
self.assertTrue(serializer.is_valid())
|
||||
self.assertEqual(serializer.errors, {})
|
||||
|
||||
def test_sync_users_password_invalid(self):
|
||||
"""Ensure only a single source with password sync can be created"""
|
||||
LDAPSource.objects.create(
|
||||
name="foo",
|
||||
slug="foo",
|
||||
slug=generate_id(),
|
||||
server_uri="ldaps://1.2.3.4",
|
||||
bind_cn="",
|
||||
bind_password=LDAP_PASSWORD,
|
||||
@@ -41,15 +51,26 @@ class LDAPAPITests(APITestCase):
|
||||
serializer = LDAPSourceSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"slug": " foo",
|
||||
"slug": generate_id(),
|
||||
"server_uri": "ldaps://1.2.3.4",
|
||||
"bind_cn": "",
|
||||
"bind_password": LDAP_PASSWORD,
|
||||
"base_dn": "dc=foo",
|
||||
"sync_users_password": False,
|
||||
"sync_users_password": True,
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{
|
||||
"sync_users_password": [
|
||||
ErrorDetail(
|
||||
string="Only a single LDAP Source with password synchronization is allowed",
|
||||
code="invalid",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
def test_sync_users_mapping_empty(self):
|
||||
"""Check that when sync_users is enabled, property mappings must be set"""
|
||||
@@ -82,3 +103,32 @@ class LDAPAPITests(APITestCase):
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
||||
@apply_blueprint("system/sources-ldap.yaml")
|
||||
def test_sync_debug(self):
|
||||
user = create_test_admin_user()
|
||||
self.client.force_login(user)
|
||||
|
||||
source: LDAPSource = LDAPSource.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
base_dn="dc=goauthentik,dc=io",
|
||||
additional_user_dn="ou=users",
|
||||
additional_group_dn="ou=groups",
|
||||
)
|
||||
source.user_property_mappings.set(
|
||||
LDAPSourcePropertyMapping.objects.filter(
|
||||
Q(managed__startswith="goauthentik.io/sources/ldap/default")
|
||||
| Q(managed__startswith="goauthentik.io/sources/ldap/ms")
|
||||
)
|
||||
)
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:ldapsource-debug", kwargs={"slug": source.slug})
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertIn("users", body)
|
||||
self.assertIn("groups", body)
|
||||
self.assertIn("membership", body)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user