mirror of
https://github.com/goauthentik/authentik
synced 2026-05-05 22:52:42 +02:00
Compare commits
132 Commits
version/20
...
cherry-pic
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
548aef46d2 | ||
|
|
7af9e98079 | ||
|
|
51901c82ba | ||
|
|
ff653005e4 | ||
|
|
9b64d05e35 | ||
|
|
99a93fa8a2 | ||
|
|
bd2a0e1d7d | ||
|
|
c4d455dd3a | ||
|
|
508dba6a04 | ||
|
|
aa921dcdca | ||
|
|
e5d873c129 | ||
|
|
f0a14d380f | ||
|
|
1da15a549e | ||
|
|
eaf1c45ea6 | ||
|
|
f0f42668c4 | ||
|
|
123fbd26bb | ||
|
|
b94d93b6c4 | ||
|
|
d0b25bf648 | ||
|
|
d4db4e50b4 | ||
|
|
c5e726d7eb | ||
|
|
203a7e0c61 | ||
|
|
2feaeff5db | ||
|
|
8fcc47e047 | ||
|
|
7a6408cc67 | ||
|
|
2da88028da | ||
|
|
fa91404895 | ||
|
|
460fce7279 | ||
|
|
995128955c | ||
|
|
85536abbcf | ||
|
|
5249546862 | ||
|
|
bf91348c05 | ||
|
|
63136f0180 | ||
|
|
faffabf938 | ||
|
|
0b180b15a2 | ||
|
|
07af6de74f | ||
|
|
ddfef91ea5 | ||
|
|
cefbf5e6ae | ||
|
|
e53d3d2486 | ||
|
|
32a3eed521 | ||
|
|
f05cc6e75a | ||
|
|
c68c36fdeb | ||
|
|
888f969fc7 | ||
|
|
82535e4671 | ||
|
|
ed2957e4e6 | ||
|
|
a5abe85148 | ||
|
|
8d2c31fa25 | ||
|
|
2637ce2474 | ||
|
|
319008dec8 | ||
|
|
8beb2fac18 | ||
|
|
ac7b28d0b0 | ||
|
|
073acf92c2 | ||
|
|
ad107c19af | ||
|
|
d285fcd8a7 | ||
|
|
84066cab48 | ||
|
|
e623d93ff5 | ||
|
|
1d0628dfbe | ||
|
|
996645105c | ||
|
|
63d7ca6ef0 | ||
|
|
5b24f4ad80 | ||
|
|
ed2e6cfb9c | ||
|
|
a1431ea48e | ||
|
|
b30e77b363 | ||
|
|
2f50cdd9fe | ||
|
|
494bdcaa09 | ||
|
|
e36ce1789e | ||
|
|
5a72ed83e0 | ||
|
|
f72d257e43 | ||
|
|
cbedb16cc4 | ||
|
|
6fc1b5ce90 | ||
|
|
57b0fa48c1 | ||
|
|
84a344ed87 | ||
|
|
f864cb56ab | ||
|
|
692735f9e1 | ||
|
|
e24fb300b1 | ||
|
|
f0e90d6873 | ||
|
|
0cf45835a0 | ||
|
|
69d35c1d26 | ||
|
|
ac803b210d | ||
|
|
c9728b4607 | ||
|
|
6e45584563 | ||
|
|
59a2e84b35 | ||
|
|
6025dbb9c9 | ||
|
|
d07bcd5025 | ||
|
|
e80655d285 | ||
|
|
e0d3d4d38c | ||
|
|
62112404ee | ||
|
|
1c9e12fcd9 | ||
|
|
42c6c257ec | ||
|
|
41bd9d7913 | ||
|
|
2c84935732 | ||
|
|
819c13a9bc | ||
|
|
0d8f366af8 | ||
|
|
093e60c753 | ||
|
|
af646f32d2 | ||
|
|
de4afc7322 | ||
|
|
bc1983106f | ||
|
|
8c2c1474f1 | ||
|
|
0dccbd4193 | ||
|
|
6a70894e01 | ||
|
|
2f5eb9b2e4 | ||
|
|
12aedb3a9e | ||
|
|
303dc93514 | ||
|
|
fbb217db57 | ||
|
|
4de253653f | ||
|
|
4154c06831 | ||
|
|
4750ed5e2a | ||
|
|
361017127d | ||
|
|
0ca5a54307 | ||
|
|
ef1aad5dbb | ||
|
|
29d880920e | ||
|
|
fc6f8374e6 | ||
|
|
a8668bbac4 | ||
|
|
d686932166 | ||
|
|
feceb220b1 | ||
|
|
937df6e07f | ||
|
|
48e6b968a6 | ||
|
|
cd89c45e75 | ||
|
|
e53995e2c1 | ||
|
|
33d5f11f0e | ||
|
|
565e16eca7 | ||
|
|
9a0164b722 | ||
|
|
8af491630b | ||
|
|
8e25e7a213 | ||
|
|
4d183657da | ||
|
|
be89b6052d | ||
|
|
ad5d2bb611 | ||
|
|
8d30fb3d25 | ||
|
|
cea3fbfa9b | ||
|
|
151d889ff4 | ||
|
|
58ca3ecbd5 | ||
|
|
1a6c7082a3 | ||
|
|
1dc60276f9 |
23
.github/actions/cherry-pick/action.yml
vendored
23
.github/actions/cherry-pick/action.yml
vendored
@@ -115,20 +115,13 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ inputs.token }}
|
||||
PR_NUMBER: ${{ steps.should_run.outputs.pr_number }}
|
||||
REASON: ${{ steps.should_run.outputs.reason }}
|
||||
run: |
|
||||
set -e -o pipefail
|
||||
PR_NUMBER="${{ steps.should_run.outputs.pr_number }}"
|
||||
|
||||
# Get PR details
|
||||
PR_DATA=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER)
|
||||
PR_TITLE=$(echo "$PR_DATA" | jq -r '.title')
|
||||
PR_AUTHOR=$(echo "$PR_DATA" | jq -r '.user.login')
|
||||
|
||||
echo "pr_title=$PR_TITLE" >> $GITHUB_OUTPUT
|
||||
echo "pr_author=$PR_AUTHOR" >> $GITHUB_OUTPUT
|
||||
|
||||
# Determine which labels to process
|
||||
if [ "${{ steps.should_run.outputs.reason }}" = "label_added_to_merged_pr" ]; then
|
||||
if [ "${REASON}" = "label_added_to_merged_pr" ]; then
|
||||
# Only process the specific label that was just added
|
||||
if [ "${{ github.event_name }}" = "issues" ]; then
|
||||
LABEL_NAME="${{ github.event.label.name }}"
|
||||
@@ -152,13 +145,13 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ inputs.token }}
|
||||
PR_NUMBER: '${{ steps.should_run.outputs.pr_number }}'
|
||||
COMMIT_SHA: '${{ steps.should_run.outputs.merge_commit_sha }}'
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
LABELS: '${{ steps.pr_details.outputs.labels }}'
|
||||
run: |
|
||||
set -e -o pipefail
|
||||
PR_NUMBER='${{ steps.should_run.outputs.pr_number }}'
|
||||
COMMIT_SHA='${{ steps.should_run.outputs.merge_commit_sha }}'
|
||||
PR_TITLE='${{ steps.pr_details.outputs.pr_title }}'
|
||||
PR_AUTHOR='${{ steps.pr_details.outputs.pr_author }}'
|
||||
LABELS='${{ steps.pr_details.outputs.labels }}'
|
||||
|
||||
echo "Processing PR #$PR_NUMBER (reason: ${{ steps.should_run.outputs.reason }})"
|
||||
echo "Found backport labels: $LABELS"
|
||||
|
||||
@@ -89,6 +89,8 @@ if should_push:
|
||||
_cache_tag = "buildcache"
|
||||
if image_arch:
|
||||
_cache_tag += f"-{image_arch}"
|
||||
if is_release:
|
||||
_cache_tag += f"-{version_family}"
|
||||
cache_to = f"type=registry,ref={get_attest_image_names(image_tags)}:{_cache_tag},mode=max"
|
||||
|
||||
|
||||
|
||||
36
.github/actions/setup/action.yml
vendored
36
.github/actions/setup/action.yml
vendored
@@ -8,18 +8,33 @@ inputs:
|
||||
postgresql_version:
|
||||
description: "Optional postgresql image tag"
|
||||
default: "16"
|
||||
working-directory:
|
||||
description: |
|
||||
Optional working directory if this repo isn't in the root of the actions workspace.
|
||||
When set, needs to contain a trailing slash
|
||||
default: ""
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install apt deps & cleanup
|
||||
- name: Cleanup apt
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: sudo apt-get remove --purge man-db
|
||||
- name: Install apt deps
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
uses: gerlero/apt-install@f4fa5265092af9e750549565d28c99aec7189639
|
||||
with:
|
||||
packages: libpq-dev openssl libxmlsec1-dev pkg-config gettext krb5-multidev libkrb5-dev heimdal-multidev libclang-dev krb5-kdc krb5-user krb5-admin-server
|
||||
update: true
|
||||
upgrade: false
|
||||
install-recommends: false
|
||||
- name: Make space on disk
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get remove --purge man-db
|
||||
sudo apt-get update
|
||||
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext krb5-multidev libkrb5-dev heimdal-multidev libclang-dev krb5-kdc krb5-user krb5-admin-server
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo mkdir -p /tmp/empty/
|
||||
sudo rsync -a --delete /tmp/empty/ /usr/local/lib/android/
|
||||
- name: Install uv
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v5
|
||||
@@ -29,24 +44,25 @@ runs:
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5
|
||||
with:
|
||||
python-version-file: "pyproject.toml"
|
||||
python-version-file: "${{ inputs.working-directory }}pyproject.toml"
|
||||
- name: Install Python deps
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: uv sync --all-extras --dev --frozen
|
||||
- name: Setup node
|
||||
if: ${{ contains(inputs.dependencies, 'node') }}
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4
|
||||
with:
|
||||
node-version-file: web/package.json
|
||||
node-version-file: ${{ inputs.working-directory }}web/package.json
|
||||
cache: "npm"
|
||||
cache-dependency-path: web/package-lock.json
|
||||
cache-dependency-path: ${{ inputs.working-directory }}web/package-lock.json
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
- name: Setup go
|
||||
if: ${{ contains(inputs.dependencies, 'go') }}
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version-file: "${{ inputs.working-directory }}go.mod"
|
||||
- name: Setup docker cache
|
||||
if: ${{ contains(inputs.dependencies, 'runtime') }}
|
||||
uses: AndreKurait/docker-cache@0fe76702a40db986d9663c24954fc14c6a6031b7
|
||||
@@ -55,6 +71,7 @@ runs:
|
||||
- name: Setup dependencies
|
||||
if: ${{ contains(inputs.dependencies, 'runtime') }}
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
export PSQL_TAG=${{ inputs.postgresql_version }}
|
||||
docker compose -f .github/actions/setup/compose.yml up -d
|
||||
@@ -62,6 +79,7 @@ runs:
|
||||
- name: Generate config
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
shell: uv run python {0}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
from authentik.lib.generators import generate_id
|
||||
from yaml import safe_dump
|
||||
|
||||
2
.github/actions/setup/compose.yml
vendored
2
.github/actions/setup/compose.yml
vendored
@@ -2,7 +2,7 @@ services:
|
||||
postgresql:
|
||||
image: docker.io/library/postgres:${PSQL_TAG:-16}
|
||||
volumes:
|
||||
- db-data:/var/lib/postgresql/data
|
||||
- db-data:/var/lib/postgresql
|
||||
command: "-c log_statement=all"
|
||||
environment:
|
||||
POSTGRES_USER: authentik
|
||||
|
||||
2
.github/workflows/api-ts-publish.yml
vendored
2
.github/workflows/api-ts-publish.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
5
.github/workflows/ci-main.yml
vendored
5
.github/workflows/ci-main.yml
vendored
@@ -95,7 +95,10 @@ jobs:
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
- name: run migrations to stable
|
||||
run: uv run python -m lifecycle.migrate
|
||||
run: |
|
||||
docker ps
|
||||
docker logs setup-postgresql-1
|
||||
uv run python -m lifecycle.migrate
|
||||
- name: checkout current code
|
||||
run: |
|
||||
set -x
|
||||
|
||||
2
.github/workflows/gen-image-compress.yml
vendored
2
.github/workflows/gen-image-compress.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
2
.github/workflows/gh-cherry-pick.yml
vendored
2
.github/workflows/gh-cherry-pick.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
if: ${{ env.GH_APP_ID != '' }}
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
env:
|
||||
GH_APP_ID: ${{ secrets.GH_APP_ID }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
|
||||
2
.github/workflows/gh-ghcr-retention.yml
vendored
2
.github/workflows/gh-ghcr-retention.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Delete 'dev' containers older than a week
|
||||
uses: snok/container-retention-policy@3b0972b2276b171b212f8c4efbca59ebba26eceb # v3.0.1
|
||||
with:
|
||||
|
||||
4
.github/workflows/release-branch-off.yml
vendored
4
.github/workflows/release-branch-off.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Checkout main
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- name: Checkout main
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
with:
|
||||
|
||||
14
.github/workflows/release-tag.yml
vendored
14
.github/workflows/release-tag.yml
vendored
@@ -70,7 +70,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
run: echo "user-id=$(gh api "/users/${{ steps.app-token.outputs.app-slug }}[bot]" --jq .id)" >> "$GITHUB_OUTPUT"
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
repositories: helm
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
repositories: version
|
||||
- id: get-user-id
|
||||
name: Get GitHub app user ID
|
||||
@@ -175,21 +175,25 @@ jobs:
|
||||
if: "${{ inputs.release_reason == 'feature' }}"
|
||||
run: |
|
||||
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}"
|
||||
reason="${{ inputs.release_reason }}"
|
||||
jq \
|
||||
--arg version "${{ inputs.version }}" \
|
||||
--arg changelog "See ${changelog_url}" \
|
||||
--arg changelog_url "${changelog_url}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
|
||||
--arg reason "${reason}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
|
||||
mv version.new.json version.json
|
||||
- name: Bump version
|
||||
if: "${{ inputs.release_reason != 'feature' }}"
|
||||
run: |
|
||||
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}#fixed-in-$(echo -n ${{ inputs.version}} | sed 's/\.//g')"
|
||||
reason="${{ inputs.release_reason }}"
|
||||
jq \
|
||||
--arg version "${{ inputs.version }}" \
|
||||
--arg changelog "See ${changelog_url}" \
|
||||
--arg changelog_url "${changelog_url}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
|
||||
--arg reason "${reason}" \
|
||||
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
|
||||
mv version.new.json version.json
|
||||
- name: Create pull request
|
||||
uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v7
|
||||
|
||||
2
.github/workflows/repo-stale.yml
vendored
2
.github/workflows/repo-stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10
|
||||
with:
|
||||
repo-token: ${{ steps.generate_token.outputs.token }}
|
||||
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
|
||||
with:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
with:
|
||||
|
||||
42
SECURITY.md
42
SECURITY.md
@@ -18,10 +18,10 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
(.x being the latest patch release for each version)
|
||||
|
||||
| Version | Supported |
|
||||
| ---------- | ---------- |
|
||||
| 2025.10.x | ✅ |
|
||||
| 2025.12.x | ✅ |
|
||||
| Version | Supported |
|
||||
| --------- | --------- |
|
||||
| 2025.12.x | ✅ |
|
||||
| 2026.2.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
@@ -60,6 +60,40 @@ authentik reserves the right to reclassify CVSS as necessary. To determine sever
|
||||
| 7.0 – 8.9 | High |
|
||||
| 9.0 – 10.0 | Critical |
|
||||
|
||||
## Intended functionality
|
||||
|
||||
The following capabilities are part of intentional system design and should not be reported as security vulnerabilities:
|
||||
|
||||
- Expressions (property mappings/policies/prompts) can execute arbitrary Python code without safeguards.
|
||||
|
||||
This is expected behavior. Any user with permission to create or modify objects containing expression fields can write code that is executed within authentik. If a vulnerability allows a user without the required permissions to write or modify code and have it executed, that would be a valid security report.
|
||||
|
||||
However, the fact that expressions are executed as part of normal operations is not considered a privilege escalation or security vulnerability.
|
||||
|
||||
- Blueprints can access all files on the filesystem.
|
||||
|
||||
This access is intentional to allow legitimate configuration and deployment tasks. It does not represent a security problem by itself.
|
||||
|
||||
- Importing blueprints allows arbitrary modification of application objects.
|
||||
|
||||
This is intended functionality. This behavior reflects the privileged design of blueprint imports. It is "exploitable" when importing blueprints from untrusted sources without reviewing the blueprint beforehand. However, any method to create, modify or execute blueprints without the required permissions would be a valid security report.
|
||||
|
||||
- Flow imports may contain objects other than flows (such as policies, users, groups, etc.)
|
||||
|
||||
This is expected behavior as flow imports are blueprint files.
|
||||
|
||||
- Prompt HTML is not escaped.
|
||||
|
||||
Prompts intentionally allow raw HTML, including script tags, so they can be used to create interactive or customized user interface elements. Because of this, scripts within prompts may affect or interact with the surrounding page as designed.
|
||||
|
||||
- Open redirects that do not include tokens or other sensitive information are not considered a security vulnerability.
|
||||
|
||||
Redirects that only change navigation flow and do not expose session tokens, API keys, or other confidential data are considered acceptable and do not require reporting.
|
||||
|
||||
- Outgoing network requests are not filtered.
|
||||
|
||||
The destinations of outgoing network requests (HTTP, TCP, etc.) made by authentik to configurable endpoints through objects such as OAuth Sources, SSO Providers, and others are not validated. Depending on your threat model, these requests should be restricted at the network level using appropriate firewall or network policies.
|
||||
|
||||
## Disclosure process
|
||||
|
||||
1. Report from Github or Issue is reported via Email as listed above.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
|
||||
VERSION = "2026.2.0-rc5"
|
||||
VERSION = "2026.2.3-rc1"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,73 @@
|
||||
"""authentik API Modelviewset tests"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.test import TestCase
|
||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||
|
||||
from authentik.admin.api.version_history import VersionHistoryViewSet
|
||||
from authentik.api.v3.urls import router
|
||||
from authentik.core.tests.utils import RequestFactory, create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.tenants.api.domains import DomainViewSet
|
||||
from authentik.tenants.api.tenants import TenantViewSet
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
class TestModelViewSets(TestCase):
|
||||
"""Test Viewset"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable:
|
||||
|
||||
def viewset_tester_factory(test_viewset: type[ModelViewSet], full=True) -> dict[str, Callable]:
|
||||
"""Test Viewset"""
|
||||
|
||||
def tester(self: TestModelViewSets):
|
||||
self.assertIsNotNone(getattr(test_viewset, "search_fields", None))
|
||||
def test_attrs(self: TestModelViewSets) -> None:
|
||||
"""Test attributes we require on all viewsets"""
|
||||
self.assertIsNotNone(getattr(test_viewset, "ordering", None))
|
||||
self.assertIsNotNone(getattr(test_viewset, "search_fields", None))
|
||||
filterset_class = getattr(test_viewset, "filterset_class", None)
|
||||
if not filterset_class:
|
||||
self.assertIsNotNone(getattr(test_viewset, "filterset_fields", None))
|
||||
|
||||
return tester
|
||||
def test_ordering(self: TestModelViewSets) -> None:
|
||||
"""Test that all ordering fields are correct"""
|
||||
view = test_viewset.as_view({"get": "list"})
|
||||
for ordering_field in test_viewset.ordering:
|
||||
with self.subTest(ordering_field):
|
||||
req = self.factory.get(
|
||||
f"/?{urlencode({'ordering': ordering_field}, doseq=True)}", user=self.user
|
||||
)
|
||||
req.tenant = get_current_tenant()
|
||||
res = view(req)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_search(self: TestModelViewSets) -> None:
|
||||
"""Test that search fields are correct"""
|
||||
view = test_viewset.as_view({"get": "list"})
|
||||
req = self.factory.get(
|
||||
f"/?{urlencode({'search': generate_id()}, doseq=True)}", user=self.user
|
||||
)
|
||||
req.tenant = get_current_tenant()
|
||||
res = view(req)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
cases = {
|
||||
"attrs": test_attrs,
|
||||
}
|
||||
if full:
|
||||
cases["ordering"] = test_ordering
|
||||
cases["search"] = test_search
|
||||
return cases
|
||||
|
||||
|
||||
for _, viewset, _ in router.registry:
|
||||
if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet):
|
||||
continue
|
||||
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset))
|
||||
full = viewset not in [VersionHistoryViewSet, DomainViewSet, TenantViewSet]
|
||||
for test, case in viewset_tester_factory(viewset, full=full).items():
|
||||
setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}_{test}", case)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from importlib import import_module
|
||||
from inspect import ismethod
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.conf import settings
|
||||
@@ -72,12 +71,19 @@ class ManagedAppConfig(AppConfig):
|
||||
|
||||
def _reconcile(self, prefix: str) -> None:
|
||||
for meth_name in dir(self):
|
||||
meth = getattr(self, meth_name)
|
||||
if not ismethod(meth):
|
||||
# Check the attribute on the class to avoid evaluating @property descriptors.
|
||||
# Using getattr(self, ...) on a @property would evaluate it, which can trigger
|
||||
# expensive side effects (e.g. tenant_schedule_specs iterating all providers
|
||||
# and running PolicyEngine queries for every user).
|
||||
class_attr = getattr(type(self), meth_name, None)
|
||||
if class_attr is None or isinstance(class_attr, property):
|
||||
continue
|
||||
category = getattr(meth, "_authentik_managed_reconcile", None)
|
||||
if not callable(class_attr):
|
||||
continue
|
||||
category = getattr(class_attr, "_authentik_managed_reconcile", None)
|
||||
if category != prefix:
|
||||
continue
|
||||
meth = getattr(self, meth_name)
|
||||
name = meth_name.replace(prefix, "")
|
||||
try:
|
||||
self.logger.debug("Starting reconciler", name=name)
|
||||
|
||||
@@ -47,7 +47,8 @@ class ApplicationEntitlementViewSet(UsedByMixin, ModelViewSet):
|
||||
search_fields = [
|
||||
"pbm_uuid",
|
||||
"name",
|
||||
"app",
|
||||
"app__name",
|
||||
"app__slug",
|
||||
"attributes",
|
||||
]
|
||||
filterset_fields = [
|
||||
|
||||
@@ -47,7 +47,12 @@ class ApplicationSerializer(ModelSerializer):
|
||||
"""Application Serializer"""
|
||||
|
||||
launch_url = SerializerMethodField()
|
||||
provider_obj = ProviderSerializer(source="get_provider", required=False, read_only=True)
|
||||
provider_obj = ProviderSerializer(
|
||||
source="get_provider",
|
||||
required=False,
|
||||
read_only=True,
|
||||
allow_null=True,
|
||||
)
|
||||
backchannel_providers_obj = ProviderSerializer(
|
||||
source="backchannel_providers", required=False, read_only=True, many=True
|
||||
)
|
||||
|
||||
@@ -32,19 +32,19 @@ from authentik.rbac.decorators import permission_required
|
||||
class UserAgentDeviceDict(TypedDict):
|
||||
"""User agent device"""
|
||||
|
||||
brand: str
|
||||
brand: str | None = None
|
||||
family: str
|
||||
model: str
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class UserAgentOSDict(TypedDict):
|
||||
"""User agent os"""
|
||||
|
||||
family: str
|
||||
major: str
|
||||
minor: str
|
||||
patch: str
|
||||
patch_minor: str
|
||||
major: str | None = None
|
||||
minor: str | None = None
|
||||
patch: str | None = None
|
||||
patch_minor: str | None = None
|
||||
|
||||
|
||||
class UserAgentBrowserDict(TypedDict):
|
||||
|
||||
@@ -17,7 +17,6 @@ from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
from django.core.validators import validate_slug
|
||||
from django.db import models
|
||||
from django.db.models import Q, QuerySet, options
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.http import HttpRequest
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.timezone import now
|
||||
@@ -45,6 +44,7 @@ from authentik.lib.models import (
|
||||
DomainlessFormattedURLValidator,
|
||||
SerializerModel,
|
||||
)
|
||||
from authentik.lib.utils.inheritance import get_deepest_child
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.rbac.models import Role
|
||||
@@ -803,25 +803,7 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
"""Get casted provider instance. Needs Application queryset with_provider"""
|
||||
if not self.provider:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
base_class = Provider
|
||||
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
|
||||
parent = self.provider
|
||||
for level in subclass.split(LOOKUP_SEP):
|
||||
try:
|
||||
parent = getattr(parent, level)
|
||||
except AttributeError:
|
||||
break
|
||||
if parent in candidates:
|
||||
continue
|
||||
idx = subclass.count(LOOKUP_SEP)
|
||||
if type(parent) is not base_class:
|
||||
idx += 1
|
||||
candidates.insert(idx, parent)
|
||||
if not candidates:
|
||||
return None
|
||||
return candidates[-1]
|
||||
return get_deepest_child(self.provider)
|
||||
|
||||
def backchannel_provider_for[T: Provider](self, provider_type: type[T], **kwargs) -> T | None:
|
||||
"""Get Backchannel provider for a specific type"""
|
||||
@@ -1133,7 +1115,11 @@ class ExpiringModel(models.Model):
|
||||
default the object is deleted. This is less efficient compared
|
||||
to bulk deleting objects, but classes like Token() need to change
|
||||
values instead of being deleted."""
|
||||
return self.delete(*args, **kwargs)
|
||||
try:
|
||||
return self.delete(*args, **kwargs)
|
||||
except self.DoesNotExist:
|
||||
# Object has already been deleted, so this should be fine
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def filter_not_expired(cls, **kwargs) -> QuerySet[Self]:
|
||||
|
||||
@@ -24,7 +24,8 @@ from authentik.root.ws.consumer import build_device_group
|
||||
|
||||
# Arguments: user: User, password: str
|
||||
password_changed = Signal()
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
|
||||
# Arguments: credentials: dict[str, any], request: HttpRequest,
|
||||
# stage: Stage, context: dict[str, any]
|
||||
login_failed = Signal()
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@@ -78,7 +78,7 @@ def generate_key_id_legacy(key_data: str) -> str:
|
||||
"""Generate Key ID using MD5 (legacy format for backwards compatibility)."""
|
||||
if not key_data:
|
||||
return ""
|
||||
return md5(key_data.encode("utf-8")).hexdigest() # nosec
|
||||
return md5(key_data.encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
|
||||
|
||||
|
||||
class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.endpoints.api.connectors import ConnectorSerializer
|
||||
from authentik.endpoints.models import EndpointStage
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.models import Connector, EndpointStage
|
||||
from authentik.flows.api.stages import StageSerializer
|
||||
|
||||
|
||||
@@ -11,6 +14,13 @@ class EndpointStageSerializer(StageSerializer):
|
||||
|
||||
connector_obj = ConnectorSerializer(source="connector", read_only=True)
|
||||
|
||||
def validate_connector(self, connector: Connector) -> Connector:
|
||||
conn: Connector = Connector.objects.get_subclass(pk=connector.pk)
|
||||
controller = conn.controller(conn)
|
||||
if Capabilities.STAGE_ENDPOINTS not in controller.capabilities():
|
||||
raise ValidationError(_("Selected connector is not compatible with this stage."))
|
||||
return connector
|
||||
|
||||
class Meta:
|
||||
model = EndpointStage
|
||||
fields = StageSerializer.Meta.fields + [
|
||||
|
||||
@@ -18,7 +18,10 @@ from authentik.rbac.decorators import permission_required
|
||||
class EnrollmentTokenSerializer(ModelSerializer):
|
||||
|
||||
device_group_obj = DeviceAccessGroupSerializer(
|
||||
source="device_group", read_only=True, required=False
|
||||
source="device_group",
|
||||
read_only=True,
|
||||
required=False,
|
||||
allow_null=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
|
||||
@@ -37,6 +37,8 @@ class AgentEnrollmentAuth(BaseAuthentication):
|
||||
token = EnrollmentToken.filter_not_expired(key=key).first()
|
||||
if not token:
|
||||
raise PermissionDenied()
|
||||
if not token.connector.enabled:
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token_enrollment")
|
||||
return (DeviceUser(), token)
|
||||
|
||||
@@ -51,6 +53,8 @@ class AgentAuth(BaseAuthentication):
|
||||
device_token = DeviceToken.filter_not_expired(key=key).first()
|
||||
if not device_token:
|
||||
raise PermissionDenied()
|
||||
if not device_token.device.connector.enabled:
|
||||
raise PermissionDenied()
|
||||
if device_token.device.device.is_expired:
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token")
|
||||
|
||||
@@ -8,7 +8,7 @@ from rest_framework.fields import CharField
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, EnrollmentToken
|
||||
from authentik.endpoints.controller import BaseController
|
||||
from authentik.endpoints.controller import BaseController, Capabilities
|
||||
from authentik.endpoints.facts import OSFamily
|
||||
|
||||
|
||||
@@ -48,8 +48,8 @@ class AgentConnectorController(BaseController[AgentConnector]):
|
||||
def vendor_identifier() -> str:
|
||||
return "goauthentik.io/platform"
|
||||
|
||||
def supported_enrollment_methods(self):
|
||||
return []
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return [Capabilities.STAGE_ENDPOINTS]
|
||||
|
||||
def generate_mdm_config(
|
||||
self, target_platform: OSFamily, request: HttpRequest, token: EnrollmentToken
|
||||
|
||||
@@ -58,6 +58,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_enroll_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-enroll"),
|
||||
data={"device_serial": generate_id(), "device_name": "bar"},
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_enroll_token_delete(self):
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-enroll"),
|
||||
@@ -104,6 +114,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@reconcile_app("authentik_crypto")
|
||||
def test_config_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:agentconnector-agent-config"),
|
||||
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_check_in(self):
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-check-in"),
|
||||
@@ -112,6 +132,16 @@ class TestAgentAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_check_in_disabled(self):
|
||||
self.connector.enabled = False
|
||||
self.connector.save()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-check-in"),
|
||||
data=CHECK_IN_DATA_VALID,
|
||||
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_check_in_token_expired(self):
|
||||
self.device_token.expiring = True
|
||||
self.device_token.expires = now() - timedelta(hours=1)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from hashlib import sha256
|
||||
from json import loads
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from jwt import encode
|
||||
@@ -232,3 +233,43 @@ class TestEndpointStage(FlowTestCase):
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertEqual(plan.context[PLAN_CONTEXT_DEVICE], self.device)
|
||||
|
||||
def test_endpoint_stage_connector_no_stage_optional(self):
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.OPTIONAL)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
with patch(
|
||||
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
|
||||
PropertyMock(return_value=None),
|
||||
):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)
|
||||
|
||||
def test_endpoint_stage_connector_no_stage_required(self):
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.REQUIRED)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
with patch(
|
||||
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
|
||||
PropertyMock(return_value=None),
|
||||
):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
self.assertStageResponse(
|
||||
res,
|
||||
component="ak-stage-access-denied",
|
||||
error_message="Invalid stage configuration",
|
||||
)
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)
|
||||
|
||||
@@ -8,13 +8,15 @@ from authentik.lib.sentry import SentryIgnoredException
|
||||
MERGED_VENDOR = "goauthentik.io/@merged"
|
||||
|
||||
|
||||
class EnrollmentMethods(models.TextChoices):
|
||||
class Capabilities(models.TextChoices):
|
||||
# Automatically enrolled through user action
|
||||
AUTOMATIC_USER = "automatic_user"
|
||||
ENROLL_AUTOMATIC_USER = "enroll_automatic_user"
|
||||
# Automatically enrolled through connector integration
|
||||
AUTOMATIC_API = "automatic_api"
|
||||
ENROLL_AUTOMATIC_API = "enroll_automatic_api"
|
||||
# Manually enrolled with user interaction (user scanning a QR code for example)
|
||||
MANUAL_USER = "manual_user"
|
||||
ENROLL_MANUAL_USER = "enroll_manual_user"
|
||||
# Supported for use with Endpoints stage
|
||||
STAGE_ENDPOINTS = "stage_endpoints"
|
||||
|
||||
|
||||
class ConnectorSyncException(SentryIgnoredException):
|
||||
@@ -34,7 +36,7 @@ class BaseController[T: "Connector"]:
|
||||
def vendor_identifier() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return []
|
||||
|
||||
def stage_view_enrollment(self) -> StageView | None:
|
||||
@@ -42,3 +44,6 @@ class BaseController[T: "Connector"]:
|
||||
|
||||
def stage_view_authentication(self) -> StageView | None:
|
||||
return None
|
||||
|
||||
def sync_endpoints(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -162,8 +162,11 @@ class Connector(ScheduledModel, SerializerModel):
|
||||
|
||||
@property
|
||||
def schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.tasks import endpoints_sync
|
||||
|
||||
if Capabilities.ENROLL_AUTOMATIC_API not in self.controller(self).capabilities():
|
||||
return []
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=endpoints_sync,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from authentik.endpoints.models import EndpointStage
|
||||
from authentik.endpoints.models import Connector, EndpointStage, StageMode
|
||||
from authentik.flows.stage import StageView
|
||||
|
||||
PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"
|
||||
@@ -6,15 +6,27 @@ PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"
|
||||
|
||||
class EndpointStageView(StageView):
|
||||
|
||||
def _get_inner(self):
|
||||
def _get_inner(self) -> StageView | None:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
inner_stage: type[StageView] | None = stage.connector.stage
|
||||
connector: Connector = stage.connector
|
||||
if not connector.enabled:
|
||||
return None
|
||||
inner_stage: type[StageView] | None = connector.stage
|
||||
if not inner_stage:
|
||||
return self.executor.stage_ok()
|
||||
return None
|
||||
return inner_stage(self.executor, request=self.request)
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self._get_inner().dispatch(request, *args, **kwargs)
|
||||
inner = self._get_inner()
|
||||
if inner is None:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
if stage.mode == StageMode.OPTIONAL:
|
||||
return self.executor.stage_ok()
|
||||
else:
|
||||
return self.executor.stage_invalid("Invalid stage configuration")
|
||||
return inner.dispatch(request, *args, **kwargs)
|
||||
|
||||
def cleanup(self):
|
||||
return self._get_inner().cleanup()
|
||||
inner = self._get_inner()
|
||||
if inner is not None:
|
||||
return inner.cleanup()
|
||||
|
||||
@@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.endpoints.controller import EnrollmentMethods
|
||||
from authentik.endpoints.controller import Capabilities
|
||||
from authentik.endpoints.models import Connector
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -17,11 +17,11 @@ def endpoints_sync(connector_pk: Any):
|
||||
connector: Connector | None = (
|
||||
Connector.objects.filter(pk=connector_pk).select_subclasses().first()
|
||||
)
|
||||
if not connector:
|
||||
if not connector or not connector.enabled:
|
||||
return
|
||||
controller = connector.controller
|
||||
ctrl = controller(connector)
|
||||
if EnrollmentMethods.AUTOMATIC_API not in ctrl.supported_enrollment_methods():
|
||||
if Capabilities.ENROLL_AUTOMATIC_API not in ctrl.capabilities():
|
||||
return
|
||||
LOGGER.info("Syncing connector", connector=connector.name)
|
||||
ctrl.sync_endpoints()
|
||||
|
||||
41
authentik/endpoints/tests/test_api.py
Normal file
41
authentik/endpoints/tests/test_api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector
|
||||
from authentik.endpoints.models import StageMode
|
||||
from authentik.enterprise.endpoints.connectors.fleet.models import FleetConnector
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestAPI(APITestCase):
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_endpoint_stage_agent(self):
|
||||
connector = AgentConnector.objects.create(name=generate_id())
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:stages-endpoint-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"connector": str(connector.pk),
|
||||
"mode": StageMode.REQUIRED,
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 201)
|
||||
|
||||
def test_endpoint_stage_fleet(self):
|
||||
connector = FleetConnector.objects.create(name=generate_id())
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:stages-endpoint-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"connector": str(connector.pk),
|
||||
"mode": StageMode.REQUIRED,
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
res.content, {"connector": ["Selected connector is not compatible with this stage."]}
|
||||
)
|
||||
35
authentik/endpoints/tests/test_tasks.py
Normal file
35
authentik/endpoints/tests/test_tasks.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.endpoints.controller import BaseController, Capabilities
|
||||
from authentik.endpoints.models import Connector
|
||||
from authentik.endpoints.tasks import endpoints_sync
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestEndpointTasks(APITestCase):
|
||||
def test_agent_sync(self):
|
||||
class controller(BaseController):
|
||||
def capabilities(self):
|
||||
return [Capabilities.ENROLL_AUTOMATIC_API]
|
||||
|
||||
def sync_endpoints(self):
|
||||
pass
|
||||
|
||||
with patch.object(Connector, "controller", PropertyMock(return_value=controller)):
|
||||
connector = Connector.objects.create(name=generate_id())
|
||||
self.assertEqual(len(connector.schedule_specs), 1)
|
||||
|
||||
endpoints_sync.send(connector.pk).get_result(block=True)
|
||||
|
||||
def test_agent_no_sync(self):
|
||||
class controller(BaseController):
|
||||
def capabilities(self):
|
||||
return []
|
||||
|
||||
with patch.object(Connector, "controller", PropertyMock(return_value=controller)):
|
||||
connector = Connector.objects.create(name=generate_id())
|
||||
self.assertEqual(len(connector.schedule_specs), 0)
|
||||
|
||||
endpoints_sync.send(connector.pk).get_result(block=True)
|
||||
@@ -3,6 +3,7 @@ from hmac import compare_digest
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, QueryDict
|
||||
|
||||
from authentik.common.oauth.constants import QS_LOGIN_HINT
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
@@ -14,7 +15,7 @@ from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE, FlowPlanner
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
||||
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme
|
||||
|
||||
QS_AGENT_IA_TOKEN = "ak-auth-ia-token" # nosec
|
||||
@@ -64,14 +65,14 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):
|
||||
|
||||
planner = FlowPlanner(self.connector.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
context = {
|
||||
PLAN_CONTEXT_DEVICE: self.device,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
|
||||
}
|
||||
if QS_LOGIN_HINT in request.GET:
|
||||
context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = request.GET[QS_LOGIN_HINT]
|
||||
try:
|
||||
plan = planner.plan(
|
||||
self.request,
|
||||
{
|
||||
PLAN_CONTEXT_DEVICE: self.device,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
|
||||
},
|
||||
)
|
||||
plan = planner.plan(self.request, context)
|
||||
except FlowNonApplicableException:
|
||||
return self.handle_no_permission_authenticated()
|
||||
plan.append_stage(in_memory_stage(AgentAuthFulfillmentStage))
|
||||
@@ -84,7 +85,6 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):
|
||||
|
||||
|
||||
class AgentAuthFulfillmentStage(StageView):
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
device: Device = self.executor.plan.context.pop(PLAN_CONTEXT_DEVICE)
|
||||
auth_token: DeviceAuthenticationToken = self.executor.plan.context.pop(
|
||||
|
||||
@@ -6,7 +6,7 @@ from requests import RequestException
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.endpoints.controller import BaseController, ConnectorSyncException, EnrollmentMethods
|
||||
from authentik.endpoints.controller import BaseController, Capabilities, ConnectorSyncException
|
||||
from authentik.endpoints.facts import (
|
||||
DeviceFacts,
|
||||
OSFamily,
|
||||
@@ -43,8 +43,8 @@ class FleetController(BaseController[DBC]):
|
||||
def vendor_identifier() -> str:
|
||||
return "fleetdm.com"
|
||||
|
||||
def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
|
||||
return [EnrollmentMethods.AUTOMATIC_API]
|
||||
def capabilities(self) -> list[Capabilities]:
|
||||
return [Capabilities.ENROLL_AUTOMATIC_API]
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
return f"{self.connector.url}{path}"
|
||||
|
||||
@@ -112,19 +112,18 @@ class LicenseKey:
|
||||
raise ValidationError("Unable to verify license") from None
|
||||
_validate_curve_original = ECAlgorithm._validate_curve
|
||||
try:
|
||||
# authentik's license used to be generated with `algorithm="ES512"` and signed with
|
||||
# authentik's license are generated with `algorithm="ES512"` and signed with
|
||||
# a key of curve `secp384r1`. Starting with version 2.11.0, pyjwt enforces the spec, see
|
||||
# https://github.com/jpadilla/pyjwt/commit/5b8622773358e56d3d3c0a9acf404809ff34433a
|
||||
# New licenses are generated with `algorithm="ES384"` and signed with `secp384r1`.
|
||||
# The last license will run out by March 2027.
|
||||
# TODO: remove this in March 2027.
|
||||
# authentik will change its license generation to `algorithm="ES384"` in 2026.
|
||||
# TODO: remove this when the last incompatible license runs out.
|
||||
ECAlgorithm._validate_curve = lambda *_: True
|
||||
body = from_dict(
|
||||
LicenseKey,
|
||||
decode(
|
||||
jwt,
|
||||
our_cert.public_key(),
|
||||
algorithms=["ES512"],
|
||||
algorithms=["ES384", "ES512"],
|
||||
audience=get_license_aud(),
|
||||
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
|
||||
),
|
||||
|
||||
@@ -331,7 +331,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test group discovery"""
|
||||
@@ -372,7 +372,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response to trigger update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -309,7 +309,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test user discovery, running multiple times"""
|
||||
@@ -352,7 +352,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response, which will trigger a discovery update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -81,6 +81,8 @@ class SignInProcessor:
|
||||
self.sign_in_request = sign_in_request
|
||||
self.saml_processor = AssertionProcessor(self.provider, self.request, AuthNRequest())
|
||||
self.saml_processor.provider.audience = self.sign_in_request.wtrealm
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor.provider.sign_assertion = True
|
||||
|
||||
def create_response_token(self):
|
||||
root = Element(f"{{{NS_WS_FED_TRUST}}}RequestSecurityTokenResponse", nsmap=NS_MAP)
|
||||
@@ -148,7 +150,8 @@ class SignInProcessor:
|
||||
def response(self) -> dict[str, str]:
|
||||
root = self.create_response_token()
|
||||
assertion = root.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self.saml_processor._sign(assertion)
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor._sign(assertion)
|
||||
str_token = etree.tostring(root).decode("utf-8") # nosec
|
||||
return delete_none_values(
|
||||
{
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
from django.urls import path
|
||||
|
||||
from authentik.enterprise.providers.ws_federation.api.providers import WSFederationProviderViewSet
|
||||
from authentik.enterprise.providers.ws_federation.views import WSFedEntryView
|
||||
from authentik.providers.saml.views.metadata import MetadataDownload
|
||||
from authentik.enterprise.providers.ws_federation.views import MetadataDownload, WSFedEntryView
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
|
||||
@@ -93,11 +93,13 @@ def on_login_failed(
|
||||
credentials: dict[str, str],
|
||||
request: HttpRequest,
|
||||
stage: Stage | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Failed Login, authentik custom event"""
|
||||
user = User.objects.filter(username=credentials.get("username")).first()
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **kwargs).from_http(
|
||||
context = context or {}
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **context).from_http(
|
||||
request, user
|
||||
)
|
||||
|
||||
|
||||
@@ -207,3 +207,9 @@ class TestEvents(TestCase):
|
||||
"username": user.username,
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_string(self):
|
||||
"""Test creating an event with invalid unicode string data"""
|
||||
event = Event.new("unittest", foo="foo bar \u0000 baz")
|
||||
event.save()
|
||||
self.assertEqual(event.context["foo"], "foo bar baz")
|
||||
|
||||
@@ -36,6 +36,10 @@ ALLOWED_SPECIAL_KEYS = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def cleanse_str(raw: Any) -> str:
|
||||
return str(raw).replace("\u0000", "")
|
||||
|
||||
|
||||
def cleanse_item(key: str, value: Any) -> Any:
|
||||
"""Cleanse a single item"""
|
||||
if isinstance(value, dict):
|
||||
@@ -66,7 +70,7 @@ def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
def model_to_dict(model: Model) -> dict[str, Any]:
|
||||
"""Convert model to dict"""
|
||||
name = str(model)
|
||||
name = cleanse_str(model)
|
||||
if hasattr(model, "name"):
|
||||
name = model.name
|
||||
return {
|
||||
@@ -133,11 +137,11 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
if isinstance(value, ASN):
|
||||
return ASN_CONTEXT_PROCESSOR.asn_to_dict(value)
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Exception):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, YAMLTag):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, type):
|
||||
@@ -161,7 +165,7 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
raise ValueError("JSON can't represent timezone-aware times.")
|
||||
return value.isoformat()
|
||||
if isinstance(value, timedelta):
|
||||
return str(value.total_seconds())
|
||||
return cleanse_str(value.total_seconds())
|
||||
if callable(value):
|
||||
return {
|
||||
"type": "callable",
|
||||
@@ -174,8 +178,8 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
try:
|
||||
return DjangoJSONEncoder().default(value)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
return cleanse_str(value)
|
||||
|
||||
|
||||
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
@@ -29,6 +29,12 @@ class RefreshOtherFlowsAfterAuthentication(Flag[bool], key="flows_refresh_others
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class ContinuousLogin(Flag[bool], key="flows_continuous_login"):
|
||||
|
||||
default = False
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class AuthentikFlowsConfig(ManagedAppConfig):
|
||||
"""authentik flows app config"""
|
||||
|
||||
|
||||
@@ -141,6 +141,10 @@ web:
|
||||
# workers: 2
|
||||
threads: 4
|
||||
path: /
|
||||
timeout_http_read_header: 5s
|
||||
timeout_http_read: 30s
|
||||
timeout_http_write: 60s
|
||||
timeout_http_idle: 120s
|
||||
|
||||
worker:
|
||||
processes: 1
|
||||
@@ -178,3 +182,5 @@ storage:
|
||||
# backend: file # or s3
|
||||
# file: {}
|
||||
# s3: {}
|
||||
|
||||
skip_migrations: false
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Any, Self
|
||||
|
||||
import pglock
|
||||
@@ -68,7 +69,12 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
return Paginator(self.get_object_qs(type), self.sync_page_size)
|
||||
|
||||
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
|
||||
num_pages: int = self.get_paginator(type).num_pages
|
||||
# Use a simple COUNT(*) on the model instead of materializing get_object_qs(),
|
||||
# which for some providers (e.g. SCIM) runs PolicyEngine per-user and is
|
||||
# extremely expensive. The time limit is an upper-bound estimate, so using
|
||||
# the total count (without policy filtering) is a safe overestimate.
|
||||
total_count = type.objects.count()
|
||||
num_pages = math.ceil(total_count / self.sync_page_size) if total_count > 0 else 1
|
||||
page_timeout_ms = timedelta_from_string(self.sync_page_timeout).total_seconds() * 1000
|
||||
return int(num_pages * page_timeout_ms * 1.5)
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ class SyncTasks:
|
||||
)
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
|
||||
self._sync_cleanup(provider, task)
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
|
||||
@@ -111,6 +112,35 @@ class SyncTasks:
|
||||
task.error(exc)
|
||||
return
|
||||
|
||||
def _sync_cleanup(self, provider: OutgoingSyncProvider, task: Task):
|
||||
"""Delete remote objects that are no longer in scope"""
|
||||
for object_type in (User, Group):
|
||||
try:
|
||||
client = provider.client_for_model(object_type)
|
||||
except TransientSyncException:
|
||||
continue
|
||||
in_scope_pks = set(provider.get_object_qs(object_type).values_list("pk", flat=True))
|
||||
stale = client.connection_type.objects.filter(provider=provider).exclude(
|
||||
**{f"{client.connection_type_query}__pk__in": in_scope_pks}
|
||||
)
|
||||
for connection in stale:
|
||||
try:
|
||||
client.delete(connection.scim_id)
|
||||
task.info(
|
||||
f"Deleted out-of-scope {object_type._meta.verbose_name}",
|
||||
scim_id=connection.scim_id,
|
||||
)
|
||||
except NotFoundSyncException:
|
||||
pass
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient error during cleanup", exc=exc)
|
||||
self.logger.warning(
|
||||
"Cleanup encountered a transient exception. Retrying", exc=exc
|
||||
)
|
||||
raise Retry() from exc
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run cleanup event", exc=exc)
|
||||
|
||||
def sync_objects(
|
||||
self,
|
||||
object_type: str,
|
||||
|
||||
119
authentik/lib/tests/test_utils_inheritance.py
Normal file
119
authentik/lib/tests/test_utils_inheritance.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Tests for inheritance helpers."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connection, models
|
||||
from django.test import TransactionTestCase
|
||||
from django.test.utils import isolate_apps
|
||||
|
||||
from authentik.lib.utils.inheritance import get_deepest_child
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_inheritance_models():
|
||||
"""Create a temporary multi-table inheritance graph for testing."""
|
||||
with isolate_apps("authentik.lib.tests"):
|
||||
|
||||
class GrandParent(models.Model):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandParent({self.pk})"
|
||||
|
||||
class Parent(GrandParent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Parent({self.pk})"
|
||||
|
||||
class Child(Parent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Child({self.pk})"
|
||||
|
||||
class GrandChild(Child):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandChild({self.pk})"
|
||||
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.create_model(GrandParent)
|
||||
schema_editor.create_model(Parent)
|
||||
schema_editor.create_model(Child)
|
||||
schema_editor.create_model(GrandChild)
|
||||
|
||||
try:
|
||||
yield GrandParent, Parent, Child, GrandChild
|
||||
finally:
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.delete_model(GrandChild)
|
||||
schema_editor.delete_model(Child)
|
||||
schema_editor.delete_model(Parent)
|
||||
schema_editor.delete_model(GrandParent)
|
||||
|
||||
|
||||
class TestInheritanceUtils(TransactionTestCase):
|
||||
"""Tests for helper functions in authentik.lib.utils.inheritance."""
|
||||
|
||||
def test_get_deepest_child_grandparent_to_parent(self):
|
||||
"""GrandParent -> Parent."""
|
||||
with temporary_inheritance_models() as (GrandParent, Parent, _Child, _GrandChild):
|
||||
parent = Parent.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=parent.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Parent)
|
||||
self.assertEqual(resolved.pk, parent.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_child(self):
|
||||
"""GrandParent -> Child."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_grandchild(self):
|
||||
"""GrandParent -> GrandChild."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=grandchild.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
self.assertEqual(resolved.pk, grandchild.pk)
|
||||
|
||||
def test_get_deepest_child_parent_to_child(self):
|
||||
"""Parent -> Child (start from non-root)."""
|
||||
with temporary_inheritance_models() as (_GrandParent, Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
parent = Parent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(parent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_no_queries_with_preloaded_relations(self):
|
||||
"""No extra queries when the inheritance chain is fully select_related."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.select_related("parent__child__grandchild").get(
|
||||
pk=grandchild.pk
|
||||
)
|
||||
|
||||
with self.assertNumQueries(0):
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
41
authentik/lib/utils/inheritance.py
Normal file
41
authentik/lib/utils/inheritance.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from django.db.models import Model, OneToOneField, OneToOneRel
|
||||
|
||||
|
||||
def get_deepest_child(parent: Model) -> Model:
|
||||
"""
|
||||
In multiple table inheritance, given any ancestor object, get the deepest child object.
|
||||
See https://docs.djangoproject.com/en/dev/topics/db/models/#multi-table-inheritance
|
||||
|
||||
This function does not query the database if `select_related` has been performed on all
|
||||
subclasses of `parent`'s model.
|
||||
"""
|
||||
|
||||
# Almost verbatim copy from django-model-utils, see
|
||||
# https://github.com/jazzband/django-model-utils/blob/5.0.0/model_utils/managers.py#L132
|
||||
one_to_one_rels = [
|
||||
field for field in parent._meta.get_fields() if isinstance(field, OneToOneRel)
|
||||
]
|
||||
|
||||
submodel_fields = [
|
||||
rel
|
||||
for rel in one_to_one_rels
|
||||
if isinstance(rel.field, OneToOneField)
|
||||
and issubclass(rel.field.model, parent._meta.model)
|
||||
and parent._meta.model is not rel.field.model
|
||||
and rel.parent_link
|
||||
]
|
||||
|
||||
submodel_accessors = [submodel_field.get_accessor_name() for submodel_field in submodel_fields]
|
||||
# End Copy
|
||||
|
||||
child = None
|
||||
for submodel in submodel_accessors:
|
||||
try:
|
||||
child = getattr(parent, submodel)
|
||||
break
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if not child:
|
||||
return parent
|
||||
return get_deepest_child(child)
|
||||
@@ -185,8 +185,10 @@ class KubernetesObjectReconciler[T]:
|
||||
|
||||
patch = self.get_patch()
|
||||
if patch is not None:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
|
||||
try:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
except AttributeError:
|
||||
current_json = asdict(current)
|
||||
try:
|
||||
if apply_patch(current_json, patch) != current_json:
|
||||
raise NeedsUpdate()
|
||||
|
||||
@@ -163,4 +163,5 @@ def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
if Outpost.objects.exists():
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
|
||||
@@ -7,7 +7,6 @@ from socket import gethostname
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@@ -159,7 +158,7 @@ def outpost_send_update(pk: Any):
|
||||
layer = get_channel_layer()
|
||||
group = build_outpost_group(outpost.pk)
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
layer.group_send_blocking(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@actor(description=_("Checks the local environment and create Service connections."))
|
||||
@@ -210,7 +209,7 @@ def outpost_session_end(session_id: str):
|
||||
for outpost in Outpost.objects.all():
|
||||
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
group,
|
||||
{
|
||||
"type": "event.session.end",
|
||||
|
||||
@@ -57,9 +57,11 @@ class PolicyBindingSerializer(ModelSerializer):
|
||||
required=True,
|
||||
)
|
||||
|
||||
policy_obj = PolicySerializer(required=False, read_only=True, source="policy")
|
||||
group_obj = PartialGroupSerializer(required=False, read_only=True, source="group")
|
||||
user_obj = PartialUserSerializer(required=False, read_only=True, source="user")
|
||||
policy_obj = PolicySerializer(required=False, allow_null=True, read_only=True, source="policy")
|
||||
group_obj = PartialGroupSerializer(
|
||||
required=False, allow_null=True, read_only=True, source="group"
|
||||
)
|
||||
user_obj = PartialUserSerializer(required=False, allow_null=True, read_only=True, source="user")
|
||||
|
||||
class Meta:
|
||||
model = PolicyBinding
|
||||
|
||||
@@ -132,9 +132,14 @@ class PolicyEngine:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
if self.mode == PolicyEngineMode.MODE_ANY:
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif self.mode == PolicyEngineMode.MODE_ALL:
|
||||
if matched_bindings.get("passing", 0) == matched_bindings["total"]:
|
||||
# All static bindings are passing -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
|
||||
@@ -33,6 +33,9 @@ class TestPolicyEngine(TestCase):
|
||||
self.policy_raises = ExpressionPolicy.objects.create(
|
||||
name=generate_id(), expression="{{ 0/0 }}"
|
||||
)
|
||||
self.group_member = Group.objects.create(name=generate_id())
|
||||
self.user.groups.add(self.group_member)
|
||||
self.group_non_member = Group.objects.create(name=generate_id())
|
||||
|
||||
def test_engine_empty(self):
|
||||
"""Ensure empty policy list passes"""
|
||||
@@ -51,7 +54,7 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ("dummy",))
|
||||
|
||||
def test_engine_mode_all(self):
|
||||
def test_engine_mode_all_dyn(self):
|
||||
"""Ensure all policies passes with AND mode (false and true -> false)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -67,7 +70,7 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_any(self):
|
||||
def test_engine_mode_any_dyn(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -83,6 +86,26 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_all_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_mode_any_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_negate(self):
|
||||
"""Test negate flag"""
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
|
||||
@@ -68,6 +68,8 @@ class IDToken:
|
||||
at_hash: str | None = None
|
||||
# Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
|
||||
sid: str | None = None
|
||||
# JWT ID, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.7
|
||||
jti: str | None = None
|
||||
|
||||
claims: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -81,6 +83,7 @@ class IDToken:
|
||||
(token.expires if token.expires is not None else default_token_duration()).timestamp()
|
||||
)
|
||||
id_token.iss = provider.get_issuer(request)
|
||||
id_token.jti = generate_id()
|
||||
id_token.aud = provider.client_id
|
||||
id_token.claims = {}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils import translation
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@@ -140,26 +141,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "+",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
provider.refresh_from_db()
|
||||
self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")])
|
||||
|
||||
def test_invalid_redirect_uri_regex(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
@@ -393,7 +374,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
"nonce": generate_id(),
|
||||
},
|
||||
)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
|
||||
self.assertEqual(
|
||||
response.url,
|
||||
@@ -465,7 +446,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
|
||||
jwt = self.validate_jwe(token, provider)
|
||||
self.assertEqual(jwt["amr"], ["pwd"])
|
||||
@@ -564,7 +545,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
token: AccessToken = AccessToken.objects.filter(user=user).first()
|
||||
token = AccessToken.objects.filter(user=user).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
@@ -690,18 +671,21 @@ class TestAuthorize(OAuthTestCase):
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
self.client.logout()
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
try:
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
finally:
|
||||
translation.deactivate()
|
||||
|
||||
@apply_blueprint("default/flow-default-authentication-flow.yaml")
|
||||
def test_ui_locales_invalid(self):
|
||||
|
||||
@@ -4,22 +4,19 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import jwt
|
||||
from django.test import RequestFactory
|
||||
from django.utils import timezone
|
||||
from dramatiq.results.errors import ResultFailure
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError, Timeout
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession, Session
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import hash_session_key
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
OAuth2LogoutMethod,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
RefreshToken,
|
||||
)
|
||||
from authentik.providers.oauth2.tasks import send_backchannel_logout_request
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
@@ -45,52 +42,6 @@ class TestBackChannelLogout(OAuthTestCase):
|
||||
self.app.provider = self.provider
|
||||
self.app.save()
|
||||
|
||||
def _create_session(self, session_key=None):
|
||||
"""Create a session with the given key or a generated one"""
|
||||
session_key = session_key or f"session-{generate_id()}"
|
||||
session = Session.objects.create(
|
||||
session_key=session_key,
|
||||
expires=timezone.now() + timezone.timedelta(hours=1),
|
||||
last_ip="255.255.255.255",
|
||||
)
|
||||
auth_session = AuthenticatedSession.objects.create(
|
||||
session=session,
|
||||
user=self.user,
|
||||
)
|
||||
return auth_session
|
||||
|
||||
def _create_token(
|
||||
self, provider, user, session=None, token_type="access", token_id=None
|
||||
): # nosec
|
||||
"""Create a token of the specified type"""
|
||||
token_id = token_id or f"{token_type}-token-{generate_id()}"
|
||||
kwargs = {
|
||||
"provider": provider,
|
||||
"user": user,
|
||||
"session": session,
|
||||
"token": token_id,
|
||||
"_id_token": "{}",
|
||||
"auth_time": timezone.now(),
|
||||
}
|
||||
|
||||
if token_type == "access": # nosec
|
||||
return AccessToken.objects.create(**kwargs)
|
||||
else: # refresh
|
||||
return RefreshToken.objects.create(**kwargs)
|
||||
|
||||
def _create_provider(self, name=None):
|
||||
"""Create an OAuth2 provider"""
|
||||
name = name or f"provider-{generate_id()}"
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=name,
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[
|
||||
RedirectURI(RedirectURIMatchingMode.STRICT, f"http://{name}/callback"),
|
||||
],
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
return provider
|
||||
|
||||
def _create_logout_token(
|
||||
self,
|
||||
provider: OAuth2Provider | None = None,
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
|
||||
|
||||
@@ -96,3 +98,70 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
def test_backchannel_client_id_via_auth_header_urlencoded(self):
|
||||
"""Test URL-encoded client IDs in Basic auth"""
|
||||
self.provider.client_id = "test/client+id"
|
||||
self.provider.save()
|
||||
creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_backchannel_scopes(self):
|
||||
"""Test backchannel"""
|
||||
self.provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-email",
|
||||
"goauthentik.io/providers/oauth2/scope-profile",
|
||||
]
|
||||
)
|
||||
)
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
data={"scope": "openid email"},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(len(token.scope), 2)
|
||||
self.assertIn("openid", token.scope)
|
||||
self.assertIn("email", token.scope)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_backchannel_scopes_extra(self):
|
||||
"""Test backchannel"""
|
||||
self.provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-email",
|
||||
"goauthentik.io/providers/oauth2/scope-profile",
|
||||
]
|
||||
)
|
||||
)
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
data={"scope": "openid email foo"},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(len(token.scope), 2)
|
||||
self.assertIn("openid", token.scope)
|
||||
self.assertIn("email", token.scope)
|
||||
|
||||
@@ -14,6 +14,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
ClientTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
@@ -43,7 +44,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
|
||||
def test_introspect_refresh(self):
|
||||
"""Test introspect"""
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -75,7 +76,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
|
||||
def test_introspect_access(self):
|
||||
"""Test introspect"""
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -130,7 +131,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
)
|
||||
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -169,3 +170,76 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_provider_public(self):
|
||||
"""Test introspect"""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-introspection"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content.decode(),
|
||||
{
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_provider_fed(self):
|
||||
"""Test introspect with federation. self.provider is a confidential
|
||||
client and other_provider is a public client."""
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-introspection"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content.decode(),
|
||||
{
|
||||
"acr": ACR_AUTHENTIK_DEFAULT,
|
||||
"sub": "bar",
|
||||
"iss": "foo",
|
||||
"active": True,
|
||||
"client_id": other_provider.client_id,
|
||||
"scope": " ".join(token.scope),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
|
||||
def test_revoke_refresh(self):
|
||||
"""Test revoke"""
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -69,7 +69,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
|
||||
def test_revoke_access(self):
|
||||
"""Test revoke"""
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -105,7 +105,19 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
"""Test revoke (invalid auth)"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION="Basic fqewr",
|
||||
HTTP_AUTHORIZATION="Basic aaa",
|
||||
data={
|
||||
"token": generate_id(),
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
def test_revoke_invalid_auth_secret(self):
|
||||
"""Test revoke (invalid secret)"""
|
||||
invalid_auth = b64encode(f"{self.provider.client_id}:aaa".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {invalid_auth}",
|
||||
data={
|
||||
"token": generate_id(),
|
||||
},
|
||||
@@ -116,7 +128,7 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
"""Test revoke public client"""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
@@ -220,3 +232,74 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
self.assertEqual(AccessToken.objects.all().count(), 0)
|
||||
self.assertEqual(RefreshToken.objects.all().count(), 0)
|
||||
self.assertEqual(DeviceToken.objects.all().count(), 0)
|
||||
|
||||
def test_revoke_provider_fed(self):
|
||||
"""Test revoke with federation. self.provider is a confidential
|
||||
client and other_provider is a public client."""
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {self.auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(res.content.decode(), {})
|
||||
|
||||
def test_revoke_provider_fed_public(self):
|
||||
"""Test revoke with federation. self.provider is a public
|
||||
client and other_provider is a public client."""
|
||||
self.provider.client_type = ClientTypes.PUBLIC
|
||||
self.provider.save()
|
||||
other_provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
|
||||
signing_key=create_test_cert(),
|
||||
client_type=ClientTypes.PUBLIC,
|
||||
)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=other_provider)
|
||||
|
||||
other_provider.jwt_federation_providers.add(self.provider)
|
||||
|
||||
token = AccessToken.objects.create(
|
||||
provider=other_provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
auth_public = b64encode(f"{self.provider.client_id}:{generate_id()}".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-revoke"),
|
||||
HTTP_AUTHORIZATION=f"Basic {auth_public}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertTrue(AccessToken.objects.filter(token=token.token).exists())
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Test token view"""
|
||||
|
||||
from base64 import b64encode
|
||||
from datetime import timedelta
|
||||
from json import dumps
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
from django.utils.timezone import now
|
||||
from freezegun import freeze_time
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.common.oauth.constants import (
|
||||
@@ -28,6 +32,7 @@ from authentik.providers.oauth2.models import (
|
||||
ScopeMapping,
|
||||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.utils import extract_client_auth
|
||||
from authentik.providers.oauth2.views.token import TokenParams
|
||||
|
||||
|
||||
@@ -97,7 +102,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -115,6 +120,20 @@ class TestToken(OAuthTestCase):
|
||||
params = TokenParams.parse(request, provider, provider.client_id, provider.client_secret)
|
||||
self.assertEqual(params.provider, provider)
|
||||
|
||||
def test_extract_client_auth_basic_auth_percent_decodes(self):
|
||||
"""test percent-decoding of client credentials in Basic auth"""
|
||||
header = b64encode(
|
||||
f"{quote('client/id', safe='')}:{quote('secret+/==', safe='')}".encode()
|
||||
).decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client/id", "secret+/=="))
|
||||
|
||||
def test_extract_client_auth_basic_auth_preserves_raw_plus(self):
|
||||
"""test compatibility with clients that still send raw plus characters"""
|
||||
header = b64encode(b"client:secret+plus").decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client", "secret+plus"))
|
||||
|
||||
def test_auth_code_view(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
@@ -140,7 +159,7 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -182,7 +201,7 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.validate_jwe(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
@@ -209,7 +228,7 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -229,10 +248,8 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
|
||||
self.assertEqual(response["Access-Control-Allow-Origin"], "http://local.invalid")
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh: RefreshToken = RefreshToken.objects.filter(
|
||||
user=user, provider=provider, revoked=False
|
||||
).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh = RefreshToken.objects.filter(user=user, provider=provider, revoked=False).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -269,7 +286,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -287,10 +304,8 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://another.invalid",
|
||||
)
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh: RefreshToken = RefreshToken.objects.filter(
|
||||
user=user, provider=provider, revoked=False
|
||||
).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
refresh = RefreshToken.objects.filter(user=user, provider=provider, revoked=False).first()
|
||||
self.assertNotIn("Access-Control-Allow-Credentials", response)
|
||||
self.assertNotIn("Access-Control-Allow-Origin", response)
|
||||
self.assertJSONEqual(
|
||||
@@ -331,7 +346,7 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
@@ -349,9 +364,7 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
new_token: RefreshToken = (
|
||||
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||
)
|
||||
new_token = RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||
# Post again with initial token -> get new refresh token
|
||||
# and revoke old one
|
||||
response = self.client.post(
|
||||
@@ -379,7 +392,11 @@ class TestToken(OAuthTestCase):
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_refresh_token_view_threshold(self):
|
||||
"""test request param"""
|
||||
"""refresh token threshold
|
||||
|
||||
threshold set to 1 hour, refresh token expires in 2 hours.
|
||||
First request should not return a new refresh token, second request
|
||||
has a fake time 1 hours in the future which should return a new access token"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
@@ -402,13 +419,14 @@ class TestToken(OAuthTestCase):
|
||||
self.app.save()
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
token: RefreshToken = RefreshToken.objects.create(
|
||||
token = RefreshToken.objects.create(
|
||||
provider=provider,
|
||||
user=user,
|
||||
token=generate_id(),
|
||||
_id_token=dumps({}),
|
||||
auth_time=timezone.now(),
|
||||
_scope="offline_access",
|
||||
expires=now() + timedelta(hours=2),
|
||||
)
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
@@ -420,9 +438,7 @@ class TestToken(OAuthTestCase):
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://local.invalid",
|
||||
)
|
||||
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
|
||||
self.assertEqual(response["Access-Control-Allow-Origin"], "http://local.invalid")
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
@@ -437,6 +453,42 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
with freeze_time(now() + timedelta(hours=1, minutes=10)):
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
data={
|
||||
"grant_type": GRANT_TYPE_REFRESH_TOKEN,
|
||||
"refresh_token": token.token,
|
||||
"redirect_uri": "http://local.invalid",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
HTTP_ORIGIN="http://local.invalid",
|
||||
)
|
||||
access = (
|
||||
AccessToken.objects.filter(user=user, provider=provider)
|
||||
.exclude(pk=access.pk)
|
||||
.first()
|
||||
)
|
||||
refresh = (
|
||||
RefreshToken.objects.filter(user=user, provider=provider)
|
||||
.exclude(pk=token.pk)
|
||||
.first()
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"access_token": access.token,
|
||||
"token_type": TOKEN_TYPE,
|
||||
"expires_in": 3600,
|
||||
"id_token": provider.encode(
|
||||
access.id_token.to_dict(),
|
||||
),
|
||||
"scope": "offline_access",
|
||||
"refresh_token": refresh.token,
|
||||
},
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_scope_claim_override_via_property_mapping(self):
|
||||
"""Test that property mappings can override the scope claim in access tokens.
|
||||
@@ -484,7 +536,7 @@ class TestToken(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
access = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
jwt_data = self.validate_jwt(access, provider)
|
||||
|
||||
# The scope should be the custom value from the property mapping,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
@@ -178,6 +179,41 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_basic_auth_urlencoded_client_secret(self):
|
||||
"""test successful with URL-encoded Basic auth credentials"""
|
||||
client_secret = b64encode(f"sa:{self.token.key}".encode()).decode()
|
||||
header = b64encode(
|
||||
f"{quote(self.provider.client_id, safe='')}:{quote(client_secret, safe='')}".encode()
|
||||
).decode()
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
{
|
||||
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["token_type"], TOKEN_TYPE)
|
||||
_, alg = self.provider.jwt_key
|
||||
jwt = decode(
|
||||
body["access_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
jwt = decode(
|
||||
body["id_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_password(self):
|
||||
"""test successful (password grant)"""
|
||||
response = self.client.post(
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestUserinfo(OAuthTestCase):
|
||||
self.app.provider = self.provider
|
||||
self.app.save()
|
||||
self.user = create_test_admin_user()
|
||||
self.token: AccessToken = AccessToken.objects.create(
|
||||
self.token = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
|
||||
@@ -7,7 +7,7 @@ from binascii import Error
|
||||
from hashlib import sha256
|
||||
from hmac import compare_digest
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.http.response import HttpResponseRedirect
|
||||
@@ -122,6 +122,10 @@ def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
|
||||
try:
|
||||
user_pass = b64decode(b64_user_pass).decode("utf-8").partition(":")
|
||||
client_id, _, client_secret = user_pass
|
||||
# RFC 6749 requires client credentials in Basic auth to be form-encoded first.
|
||||
# We only percent-decode here so raw `+` characters keep their previous meaning.
|
||||
client_id = unquote(client_id)
|
||||
client_secret = unquote(client_secret)
|
||||
except ValueError, Error:
|
||||
client_id = client_secret = "" # nosec
|
||||
else:
|
||||
|
||||
@@ -58,7 +58,6 @@ from authentik.providers.oauth2.models import (
|
||||
AuthorizationCode,
|
||||
GrantTypes,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
RedirectURIMatchingMode,
|
||||
ResponseMode,
|
||||
ResponseTypes,
|
||||
@@ -196,14 +195,6 @@ class OAuthAuthorizationParams:
|
||||
LOGGER.warning("Missing redirect uri.")
|
||||
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||
|
||||
if len(allowed_redirect_urls) < 1:
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
self.provider.redirect_uris = [
|
||||
RedirectURI(RedirectURIMatchingMode.STRICT, self.redirect_uri)
|
||||
]
|
||||
self.provider.save()
|
||||
allowed_redirect_urls = self.provider.redirect_uris
|
||||
|
||||
match_found = False
|
||||
for allowed in allowed_redirect_urls:
|
||||
if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
|
||||
|
||||
@@ -15,7 +15,7 @@ from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
|
||||
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
@@ -28,7 +28,7 @@ class DeviceView(View):
|
||||
|
||||
client_id: str
|
||||
provider: OAuth2Provider
|
||||
scopes: list[str] = []
|
||||
scopes: set[str] = []
|
||||
|
||||
def parse_request(self):
|
||||
"""Parse incoming request"""
|
||||
@@ -44,7 +44,21 @@ class DeviceView(View):
|
||||
raise DeviceCodeError("invalid_client") from None
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
self.scopes = self.request.POST.get("scope", "").split(" ")
|
||||
|
||||
scopes_to_check = set(self.request.POST.get("scope", "").split())
|
||||
default_scope_names = set(
|
||||
ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
|
||||
"scope_name", flat=True
|
||||
)
|
||||
)
|
||||
self.scopes = scopes_to_check
|
||||
if not scopes_to_check.issubset(default_scope_names):
|
||||
LOGGER.info(
|
||||
"Application requested scopes not configured, setting to overlap",
|
||||
scope_allowed=default_scope_names,
|
||||
scope_given=self.scopes,
|
||||
)
|
||||
self.scopes = self.scopes.intersection(default_scope_names)
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
throttle = AnonRateThrottle()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from django.db.models import Q
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views import View
|
||||
@@ -10,7 +11,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.providers.oauth2.errors import TokenIntrospectionError
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.models import AccessToken, ClientTypes, OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.utils import TokenResponse, authenticate_provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -33,10 +34,7 @@ class TokenIntrospectionParams:
|
||||
self.id_token = self.token.id_token
|
||||
|
||||
if not self.token.id_token:
|
||||
LOGGER.debug(
|
||||
"token not an authentication token",
|
||||
token=self.token,
|
||||
)
|
||||
LOGGER.debug("token not an authentication token", token=self.token)
|
||||
raise TokenIntrospectionError()
|
||||
|
||||
@staticmethod
|
||||
@@ -45,14 +43,23 @@ class TokenIntrospectionParams:
|
||||
raw_token = request.POST.get("token")
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
LOGGER.info("Failed to authenticate introspection request")
|
||||
raise TokenIntrospectionError
|
||||
if provider.client_type != ClientTypes.CONFIDENTIAL:
|
||||
LOGGER.info("Introspection request from public provider, denying.")
|
||||
raise TokenIntrospectionError
|
||||
|
||||
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
query = Q(
|
||||
Q(provider=provider) | Q(provider__jwt_federation_providers__in=[provider]),
|
||||
token=raw_token,
|
||||
)
|
||||
|
||||
access_token = AccessToken.objects.filter(query).first()
|
||||
if access_token:
|
||||
return TokenIntrospectionParams(access_token, provider)
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
return TokenIntrospectionParams(access_token, access_token.provider)
|
||||
refresh_token = RefreshToken.objects.filter(query).first()
|
||||
if refresh_token:
|
||||
return TokenIntrospectionParams(refresh_token, provider)
|
||||
return TokenIntrospectionParams(refresh_token, refresh_token.provider)
|
||||
LOGGER.debug("Token does not exist", token=raw_token)
|
||||
raise TokenIntrospectionError()
|
||||
|
||||
|
||||
@@ -704,7 +704,7 @@ class TokenView(View):
|
||||
refresh_token_threshold = timedelta_from_string(self.provider.refresh_token_threshold)
|
||||
if (
|
||||
refresh_token_threshold.total_seconds() == 0
|
||||
or (now - self.params.refresh_token.expires) > refresh_token_threshold
|
||||
or (self.params.refresh_token.expires - now) < refresh_token_threshold
|
||||
):
|
||||
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
|
||||
refresh_token = RefreshToken(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from django.db.models import Q
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views import View
|
||||
@@ -32,15 +33,25 @@ class TokenRevocationParams:
|
||||
raw_token = request.POST.get("token")
|
||||
|
||||
provider, _, _ = provider_from_request(request)
|
||||
if provider and provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
# By default clients can only revoke their own tokens
|
||||
query = Q(provider=provider, token=raw_token)
|
||||
if provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||
provider = authenticate_provider(request)
|
||||
if not provider:
|
||||
raise TokenRevocationError("invalid_client")
|
||||
# If the request is authenticated by a confidential provider, it can also
|
||||
# revoke federated tokens
|
||||
query = Q(
|
||||
Q(provider=provider) | Q(provider__jwt_federation_providers__in=[provider]),
|
||||
token=raw_token,
|
||||
)
|
||||
|
||||
access_token = AccessToken.objects.filter(token=raw_token).first()
|
||||
access_token = AccessToken.objects.filter(query).first()
|
||||
if access_token:
|
||||
return TokenRevocationParams(access_token, provider)
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token).first()
|
||||
refresh_token = RefreshToken.objects.filter(query).first()
|
||||
if refresh_token:
|
||||
return TokenRevocationParams(refresh_token, provider)
|
||||
LOGGER.debug("Token does not exist", token=raw_token)
|
||||
|
||||
@@ -27,6 +27,8 @@ class TraefikMiddlewareSpecForwardAuth:
|
||||
|
||||
trustForwardHeader: bool = field(default=True)
|
||||
|
||||
maxResponseBodySize: int = field(default=1024 * 1024 * 4)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TraefikMiddlewareSpec:
|
||||
@@ -140,6 +142,7 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
|
||||
],
|
||||
authResponseHeadersRegex="",
|
||||
trustForwardHeader=True,
|
||||
maxResponseBodySize=1024 * 1024 * 4,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Proxy provider signals"""
|
||||
|
||||
from django.db.models.signals import pre_delete
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.core.models import AuthenticatedSession
|
||||
from authentik.providers.proxy.tasks import proxy_on_logout
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
proxy_on_logout.send(instance.session.session_key)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""proxy provider tasks"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.outposts.consumer import build_outpost_group
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.oauth2.id_token import hash_session_key
|
||||
|
||||
|
||||
@actor(description=_("Terminate session on Proxy outpost."))
|
||||
def proxy_on_logout(session_id: str):
|
||||
layer = get_channel_layer()
|
||||
hashed_session_id = hash_session_key(session_id)
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
group,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": hashed_session_id,
|
||||
},
|
||||
)
|
||||
@@ -231,8 +231,8 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
class SAMLMetadataSerializer(PassiveSerializer):
|
||||
"""SAML Provider Metadata serializer"""
|
||||
|
||||
metadata = CharField(read_only=True)
|
||||
download_url = CharField(read_only=True, required=False)
|
||||
metadata = CharField()
|
||||
download_url = CharField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class SAMLProviderImportSerializer(PassiveSerializer):
|
||||
@@ -314,7 +314,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
return response
|
||||
return Response({"metadata": metadata}, content_type="application/json")
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return Response({"metadata": ""}, content_type="application/json")
|
||||
raise Http404 from None
|
||||
|
||||
@permission_required(
|
||||
None,
|
||||
|
||||
@@ -157,7 +157,7 @@ class TestSAMLProviderAPI(APITestCase):
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-metadata", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(404, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-metadata", kwargs={"pk": "abc"}),
|
||||
)
|
||||
|
||||
@@ -8,11 +8,11 @@ from django.urls import reverse
|
||||
|
||||
from authentik.common.saml.constants import SAML_NAME_ID_FORMAT_EMAIL
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_brand, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_brand, create_test_cert, create_test_flow
|
||||
from authentik.flows.planner import FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.providers.saml.exceptions import CannotHandleAssertion
|
||||
from authentik.providers.saml.models import SAMLProvider
|
||||
from authentik.providers.saml.models import SAMLBindings, SAMLLogoutMethods, SAMLProvider
|
||||
from authentik.providers.saml.processors.logout_request import LogoutRequestProcessor
|
||||
from authentik.providers.saml.views.flows import PLAN_CONTEXT_SAML_RELAY_STATE
|
||||
from authentik.providers.saml.views.sp_slo import (
|
||||
@@ -91,32 +91,33 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_redirect_view_handles_logout_response_with_relay_state(self):
|
||||
"""Test that redirect view handles logout response with RelayState"""
|
||||
# Use raw URL (no encoding needed)
|
||||
relay_state = "https://idp.example.com/flow/return"
|
||||
def test_redirect_view_handles_logout_response_with_plan_context(self):
|
||||
"""Test that redirect view always redirects to plan context URL, ignoring RelayState"""
|
||||
plan_relay_state = "https://idp.example.com/flow/return"
|
||||
|
||||
# Create request with SAML logout response
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": relay_state,
|
||||
"RelayState": "https://somewhere-else.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to relay state URL
|
||||
# Should redirect to plan context URL, not the request's RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_redirect_view_handles_logout_response_plain_relay_state(self):
|
||||
"""Test that redirect view handles logout response with plain RelayState"""
|
||||
def test_redirect_view_ignores_relay_state_without_plan(self):
|
||||
"""Test that redirect view ignores RelayState and falls back to root when no plan context"""
|
||||
relay_state = "https://sp.example.com/plain"
|
||||
|
||||
# Create request with SAML logout response
|
||||
@@ -134,9 +135,9 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to plain relay state
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_handles_logout_response_no_relay_state_with_plan_context(self):
|
||||
"""Test that redirect view uses plan context fallback when no RelayState"""
|
||||
@@ -228,29 +229,30 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
self.assertEqual(logout_request.issuer, self.provider.issuer)
|
||||
self.assertEqual(logout_request.session_index, "test-session-123")
|
||||
|
||||
def test_post_view_handles_logout_response_with_relay_state(self):
|
||||
"""Test that POST view handles logout response with RelayState"""
|
||||
# Use raw URL (no encoding needed)
|
||||
relay_state = "https://idp.example.com/flow/return"
|
||||
def test_post_view_handles_logout_response_with_plan_context(self):
|
||||
"""Test that POST view always redirects to plan context URL, ignoring RelayState"""
|
||||
plan_relay_state = "https://idp.example.com/flow/return"
|
||||
|
||||
# Create POST request with SAML logout response
|
||||
request = self.factory.post(
|
||||
f"/slo/post/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": relay_state,
|
||||
"RelayState": "https://somewhere-else.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingPOSTView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should redirect to relay state URL
|
||||
# Should redirect to plan context URL, not the request's RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, relay_state)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_post_view_handles_logout_response_no_relay_state_with_plan_context(self):
|
||||
"""Test that POST view uses plan context fallback when no RelayState"""
|
||||
@@ -417,7 +419,7 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.resolve_provider_application()
|
||||
|
||||
def test_relay_state_decoding_failure(self):
|
||||
"""Test handling of RelayState that's a path"""
|
||||
"""Test that arbitrary path RelayState is ignored and redirects to root"""
|
||||
# Create request with relay state that is a path
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
@@ -433,6 +435,357 @@ class TestSPInitiatedSLOViews(TestCase):
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should treat it as plain URL and redirect to it
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, "/some/invalid/path")
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_blocks_external_relay_state(self):
|
||||
"""Test that redirect view ignores external malicious URL and redirects to root"""
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
def test_redirect_view_ignores_relay_state_uses_plan_context(self):
|
||||
"""Test that redirect view always uses plan context URL regardless of RelayState"""
|
||||
plan_relay_state = "https://authentik.example.com/if/flow/logout/"
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
plan = FlowPlan(flow_pk="test-flow")
|
||||
plan.context[PLAN_CONTEXT_SAML_RELAY_STATE] = plan_relay_state
|
||||
request.session = {SESSION_KEY_PLAN: plan}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should always use plan context value, ignoring malicious RelayState
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, plan_relay_state)
|
||||
|
||||
def test_post_view_ignores_external_relay_state(self):
|
||||
"""Test that POST view ignores external RelayState and redirects to root"""
|
||||
request = self.factory.post(
|
||||
f"/slo/post/{self.application.slug}/",
|
||||
{
|
||||
"SAMLResponse": "dummy-response",
|
||||
"RelayState": "https://evil.com/phishing",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
|
||||
view = SPInitiatedSLOBindingPOSTView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
response = view.dispatch(request, application_slug=self.application.slug)
|
||||
|
||||
# Should ignore relay_state and redirect to root (no plan context)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
|
||||
|
||||
|
||||
class TestSPInitiatedSLOLogoutMethods(TestCase):
|
||||
"""Test SP-initiated SAML SLO logout method branching"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.factory = RequestFactory()
|
||||
self.brand = create_test_brand()
|
||||
self.flow = create_test_flow()
|
||||
self.invalidation_flow = create_test_flow()
|
||||
self.cert = create_test_cert()
|
||||
|
||||
# Create provider with sls_url
|
||||
self.provider = SAMLProvider.objects.create(
|
||||
name="test-provider",
|
||||
authorization_flow=self.flow,
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="https://sp.example.com/sls",
|
||||
issuer="https://idp.example.com",
|
||||
sp_binding="redirect",
|
||||
sls_binding="redirect",
|
||||
signing_kp=self.cert,
|
||||
)
|
||||
|
||||
# Create application
|
||||
self.application = Application.objects.create(
|
||||
name="test-app",
|
||||
slug="test-app-logout-methods",
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
# Create logout request processor for generating test requests
|
||||
self.processor = LogoutRequestProcessor(
|
||||
provider=self.provider,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
relay_state="https://sp.example.com/return",
|
||||
)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_native_post_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_NATIVE with POST binding parses request correctly"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_NATIVE
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed and provider is configured correctly
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
self.assertEqual(view.provider.logout_method, SAMLLogoutMethods.FRONTCHANNEL_NATIVE)
|
||||
self.assertEqual(view.provider.sls_binding, SAMLBindings.POST)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_native_redirect_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_NATIVE with REDIRECT binding creates redirect URL"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_NATIVE
|
||||
self.provider.sls_binding = SAMLBindings.REDIRECT
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_iframe_post_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_IFRAME with POST binding creates IframeLogoutStageView"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_frontchannel_iframe_redirect_binding(self, mock_auth_session):
|
||||
"""Test FRONTCHANNEL_IFRAME with REDIRECT binding"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.sls_binding = SAMLBindings.REDIRECT
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_backchannel_parses_request(self, mock_auth_session):
|
||||
"""Test BACKCHANNEL mode parses request correctly"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.BACKCHANNEL
|
||||
self.provider.sls_binding = SAMLBindings.POST
|
||||
self.provider.save()
|
||||
|
||||
encoded_request = self.processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": "https://sp.example.com/return",
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the logout request was parsed and provider is configured correctly
|
||||
self.assertIn("authentik/providers/saml/logout_request", view.plan_context)
|
||||
self.assertEqual(view.provider.logout_method, SAMLLogoutMethods.BACKCHANNEL)
|
||||
self.assertEqual(view.provider.sls_binding, SAMLBindings.POST)
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_no_sls_url_only_session_end(self, mock_auth_session):
|
||||
"""Test that only SessionEndStage is appended when sls_url is empty"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
# Create provider without sls_url
|
||||
provider_no_sls = SAMLProvider.objects.create(
|
||||
name="no-sls-provider",
|
||||
authorization_flow=self.flow,
|
||||
invalidation_flow=self.invalidation_flow,
|
||||
acs_url="https://sp.example.com/acs",
|
||||
sls_url="", # No SLS URL
|
||||
issuer="https://idp.example.com",
|
||||
)
|
||||
|
||||
app_no_sls = Application.objects.create(
|
||||
name="no-sls-app",
|
||||
slug="no-sls-app",
|
||||
provider=provider_no_sls,
|
||||
)
|
||||
|
||||
processor = LogoutRequestProcessor(
|
||||
provider=provider_no_sls,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
)
|
||||
encoded_request = processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{app_no_sls.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=app_no_sls.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify the provider has no sls_url
|
||||
self.assertEqual(view.provider.sls_url, "")
|
||||
|
||||
@patch("authentik.providers.saml.views.sp_slo.AuthenticatedSession")
|
||||
def test_relay_state_propagation(self, mock_auth_session):
|
||||
"""Test that relay state from logout request is passed through to response"""
|
||||
mock_auth_session.from_request.return_value = None
|
||||
|
||||
self.provider.logout_method = SAMLLogoutMethods.FRONTCHANNEL_IFRAME
|
||||
self.provider.save()
|
||||
|
||||
expected_relay_state = "https://sp.example.com/custom-return"
|
||||
|
||||
processor = LogoutRequestProcessor(
|
||||
provider=self.provider,
|
||||
user=None,
|
||||
destination="https://idp.example.com/sls",
|
||||
name_id="test@example.com",
|
||||
name_id_format=SAML_NAME_ID_FORMAT_EMAIL,
|
||||
session_index="test-session-123",
|
||||
relay_state=expected_relay_state,
|
||||
)
|
||||
encoded_request = processor.encode_redirect()
|
||||
|
||||
request = self.factory.get(
|
||||
f"/slo/redirect/{self.application.slug}/",
|
||||
{
|
||||
"SAMLRequest": encoded_request,
|
||||
"RelayState": expected_relay_state,
|
||||
},
|
||||
)
|
||||
request.session = {}
|
||||
request.brand = self.brand
|
||||
request.user = MagicMock()
|
||||
|
||||
view = SPInitiatedSLOBindingRedirectView()
|
||||
view.setup(request, application_slug=self.application.slug)
|
||||
view.resolve_provider_application()
|
||||
view.check_saml_request()
|
||||
|
||||
# Verify relay state was captured
|
||||
logout_request = view.plan_context.get("authentik/providers/saml/logout_request")
|
||||
self.assertEqual(logout_request.relay_state, expected_relay_state)
|
||||
|
||||
@@ -29,6 +29,24 @@ from authentik.providers.saml.views.flows import (
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def _get_redirect_url(request: HttpRequest, relay_state: str = "") -> str:
|
||||
"""Get the safe redirect URL from the plan context, logging a warning if the
|
||||
incoming relay_state doesn't match the stored value."""
|
||||
stored_relay_state = ""
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
stored_relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE, "")
|
||||
|
||||
if relay_state and relay_state != stored_relay_state:
|
||||
LOGGER.warning(
|
||||
"SAML logout relay_state mismatch, possible open redirect attempt",
|
||||
received_relay_state=relay_state,
|
||||
stored_relay_state=stored_relay_state,
|
||||
)
|
||||
|
||||
return stored_relay_state
|
||||
|
||||
|
||||
class SPInitiatedSLOView(PolicyAccessView):
|
||||
"""Handle SP-initiated SAML Single Logout requests"""
|
||||
|
||||
@@ -96,17 +114,9 @@ class SPInitiatedSLOBindingRedirectView(SPInitiatedSLOView):
|
||||
# IDP SLO, so we want to redirect to our next provider
|
||||
if REQUEST_KEY_SAML_RESPONSE in request.GET:
|
||||
relay_state = request.GET.get(REQUEST_KEY_RELAY_STATE, "")
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No RelayState provided, try to get return URL from plan context
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE)
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No relay state and no plan context - redirect to root
|
||||
redirect_url = _get_redirect_url(request, relay_state)
|
||||
if redirect_url:
|
||||
return redirect(redirect_url)
|
||||
return redirect("authentik_core:root-redirect")
|
||||
|
||||
# For SAML logout requests, use the parent dispatch with auth checks
|
||||
@@ -147,17 +157,9 @@ class SPInitiatedSLOBindingPOSTView(SPInitiatedSLOView):
|
||||
# IDP SLO, so we want to redirect to our next provider
|
||||
if REQUEST_KEY_SAML_RESPONSE in request.POST:
|
||||
relay_state = request.POST.get(REQUEST_KEY_RELAY_STATE, "")
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No RelayState provided, try to get return URL from plan context
|
||||
if SESSION_KEY_PLAN in request.session:
|
||||
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
relay_state = plan.context.get(PLAN_CONTEXT_SAML_RELAY_STATE)
|
||||
if relay_state:
|
||||
return redirect(relay_state)
|
||||
|
||||
# No relay state and no plan context - redirect to root
|
||||
redirect_url = _get_redirect_url(request, relay_state)
|
||||
if redirect_url:
|
||||
return redirect(redirect_url)
|
||||
return redirect("authentik_core:root-redirect")
|
||||
|
||||
# For SAML logout requests, use the parent dispatch with auth checks
|
||||
|
||||
@@ -10,6 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderGroup
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
|
||||
|
||||
class SCIMGroupTests(TestCase):
|
||||
@@ -205,3 +206,80 @@ class SCIMGroupTests(TestCase):
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
self.assertNotIn("PUT", [req.method for req in mock.request_history])
|
||||
|
||||
def _create_stale_provider_group(self, scim_id: str) -> Group:
|
||||
"""Create a group that is outside the provider's scope (via group_filters) with an
|
||||
existing SCIMProviderGroup, simulating a previously synced group now out of scope."""
|
||||
self.app.backchannel_providers.remove(self.provider)
|
||||
anchor = Group.objects.create(name=generate_id())
|
||||
stale = Group.objects.create(name=generate_id())
|
||||
self.app.backchannel_providers.add(self.provider)
|
||||
|
||||
self.provider.group_filters.set([anchor])
|
||||
SCIMProviderGroup.objects.create(provider=self.provider, group=stale, scim_id=scim_id)
|
||||
return stale
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) groups are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Groups/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_not_found(self, mock: Mocker):
|
||||
"""Stale group cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_transient_error(self, mock: Mocker):
|
||||
"""Stale group cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_dry_run(self, mock: Mocker):
|
||||
"""Stale group cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""SCIM User tests"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from jsonschema import validate
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.core.models import Application, Group, User, UserTypes
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.sync.outgoing.base import SAFE_METHODS
|
||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderUser
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects, sync_tasks
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@@ -537,3 +539,104 @@ class SCIMUserTests(TestCase):
|
||||
self.assertEqual(mock.call_count, 2)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
|
||||
def _create_stale_provider_user(self, scim_id: str, uid: str) -> User:
|
||||
"""Create a service-account user (excluded from provider scope) with an existing
|
||||
SCIMProviderUser, simulating a previously synced user that is now out of scope."""
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
type=UserTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
SCIMProviderUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
|
||||
return user
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) users are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Users/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_not_found(self, mock: Mocker):
|
||||
"""Stale user cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_transient_error(self, mock: Mocker):
|
||||
"""Stale user cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_dry_run(self, mock: Mocker):
|
||||
"""Stale user cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
def test_sync_cleanup_client_for_model_transient(self):
|
||||
"""Cleanup silently skips an object type when client_for_model raises
|
||||
TransientSyncException"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
SCIMProvider,
|
||||
"client_for_model",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
def test_sync_transient_exception(self):
|
||||
"""TransientSyncException in _sync_cleanup is caught by sync() which then
|
||||
schedules a retry"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
sync_tasks,
|
||||
"_sync_cleanup",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
@@ -6,6 +6,7 @@ from django.db import migrations, models
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0056_user_roles"), # must run before group field is removed
|
||||
("authentik_rbac", "0009_remove_initialpermissions_mode"),
|
||||
]
|
||||
|
||||
|
||||
@@ -60,11 +60,7 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
sources = sources.exclude(pk=self.instance.pk)
|
||||
if sources.exists():
|
||||
raise ValidationError(
|
||||
{
|
||||
"sync_users_password": _(
|
||||
"Only a single LDAP Source with password synchronization is allowed"
|
||||
)
|
||||
}
|
||||
_("Only a single LDAP Source with password synchronization is allowed")
|
||||
)
|
||||
return sync_users_password
|
||||
|
||||
@@ -221,7 +217,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
for sync_class in SYNC_CLASSES:
|
||||
class_name = sync_class.name()
|
||||
all_objects.setdefault(class_name, [])
|
||||
for page in sync_class(source).get_objects(size_limit=10):
|
||||
for page in sync_class(source, Task()).get_objects(size_limit=10):
|
||||
for obj in page:
|
||||
obj: dict
|
||||
obj.pop("raw_attributes", None)
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
from ldap3 import ALL, NONE, RANDOM, Connection, Server, ServerPool, Tls
|
||||
from ldap3.core.exceptions import LDAPException, LDAPInsufficientAccessRightsResult, LDAPSchemaError
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import (
|
||||
Group,
|
||||
@@ -31,6 +32,7 @@ from authentik.tasks.schedules.common import ScheduleSpec
|
||||
LDAP_TIMEOUT = 15
|
||||
LDAP_UNIQUENESS = "ldap_uniq"
|
||||
LDAP_DISTINGUISHED_NAME = "distinguishedName"
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def flatten(value: Any) -> Any:
|
||||
@@ -268,6 +270,7 @@ class LDAPSource(IncomingSyncSource):
|
||||
)
|
||||
|
||||
if self.start_tls:
|
||||
LOGGER.debug("Connection StartTLS", source=self)
|
||||
conn.start_tls(read_server_info=False)
|
||||
try:
|
||||
successful = conn.bind()
|
||||
@@ -278,7 +281,9 @@ class LDAPSource(IncomingSyncSource):
|
||||
# See https://github.com/goauthentik/authentik/issues/4590
|
||||
# See also https://github.com/goauthentik/authentik/issues/3399
|
||||
if server_kwargs.get("get_info", ALL) == NONE:
|
||||
LOGGER.warning("Failed to connect after schema downgrade", source=self, exc=exc)
|
||||
raise exc
|
||||
LOGGER.warning("Downgrading connection to no schema info", source=self, exc=exc)
|
||||
server_kwargs["get_info"] = NONE
|
||||
return self.connection(server, server_kwargs, connection_kwargs)
|
||||
finally:
|
||||
|
||||
@@ -27,7 +27,7 @@ from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
SYNC_CLASSES = [
|
||||
SYNC_CLASSES: list[type[BaseLDAPSynchronizer]] = [
|
||||
UserLDAPSynchronizer,
|
||||
GroupLDAPSynchronizer,
|
||||
MembershipLDAPSynchronizer,
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
"""LDAP Source API tests"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.db.models import Q
|
||||
from django.urls import reverse
|
||||
from rest_framework.exceptions import ErrorDetail
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.lib.generators import generate_key
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.sources.ldap.api import LDAPSourceSerializer
|
||||
from authentik.sources.ldap.models import LDAPSource
|
||||
from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
|
||||
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
|
||||
|
||||
LDAP_PASSWORD = generate_key()
|
||||
|
||||
@@ -26,12 +35,13 @@ class LDAPAPITests(APITestCase):
|
||||
}
|
||||
)
|
||||
self.assertTrue(serializer.is_valid())
|
||||
self.assertEqual(serializer.errors, {})
|
||||
|
||||
def test_sync_users_password_invalid(self):
|
||||
"""Ensure only a single source with password sync can be created"""
|
||||
LDAPSource.objects.create(
|
||||
name="foo",
|
||||
slug="foo",
|
||||
slug=generate_id(),
|
||||
server_uri="ldaps://1.2.3.4",
|
||||
bind_cn="",
|
||||
bind_password=LDAP_PASSWORD,
|
||||
@@ -41,15 +51,26 @@ class LDAPAPITests(APITestCase):
|
||||
serializer = LDAPSourceSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"slug": " foo",
|
||||
"slug": generate_id(),
|
||||
"server_uri": "ldaps://1.2.3.4",
|
||||
"bind_cn": "",
|
||||
"bind_password": LDAP_PASSWORD,
|
||||
"base_dn": "dc=foo",
|
||||
"sync_users_password": False,
|
||||
"sync_users_password": True,
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{
|
||||
"sync_users_password": [
|
||||
ErrorDetail(
|
||||
string="Only a single LDAP Source with password synchronization is allowed",
|
||||
code="invalid",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
def test_sync_users_mapping_empty(self):
|
||||
"""Check that when sync_users is enabled, property mappings must be set"""
|
||||
@@ -82,3 +103,32 @@ class LDAPAPITests(APITestCase):
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
||||
@apply_blueprint("system/sources-ldap.yaml")
|
||||
def test_sync_debug(self):
|
||||
user = create_test_admin_user()
|
||||
self.client.force_login(user)
|
||||
|
||||
source: LDAPSource = LDAPSource.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
base_dn="dc=goauthentik,dc=io",
|
||||
additional_user_dn="ou=users",
|
||||
additional_group_dn="ou=groups",
|
||||
)
|
||||
source.user_property_mappings.set(
|
||||
LDAPSourcePropertyMapping.objects.filter(
|
||||
Q(managed__startswith="goauthentik.io/sources/ldap/default")
|
||||
| Q(managed__startswith="goauthentik.io/sources/ldap/ms")
|
||||
)
|
||||
)
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:ldapsource-debug", kwargs={"slug": source.slug})
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertIn("users", body)
|
||||
self.assertIn("groups", body)
|
||||
self.assertIn("membership", body)
|
||||
|
||||
@@ -59,7 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):
|
||||
|
||||
def validate(self, attrs: dict) -> dict:
|
||||
session = get_http_session()
|
||||
source_type = registry.find_type(attrs["provider_type"])
|
||||
provider_type_name = attrs.get(
|
||||
"provider_type",
|
||||
self.instance.provider_type if self.instance else None,
|
||||
)
|
||||
source_type = registry.find_type(provider_type_name)
|
||||
|
||||
well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url
|
||||
inferred_oidc_jwks_url = None
|
||||
@@ -101,16 +105,15 @@ class OAuthSourceSerializer(SourceSerializer):
|
||||
config = jwks_config.json()
|
||||
attrs["oidc_jwks"] = config
|
||||
|
||||
provider_type = registry.find_type(attrs.get("provider_type", ""))
|
||||
for url in [
|
||||
"authorization_url",
|
||||
"access_token_url",
|
||||
"profile_url",
|
||||
]:
|
||||
if getattr(provider_type, url, None) is None:
|
||||
if getattr(source_type, url, None) is None:
|
||||
if url not in attrs:
|
||||
raise ValidationError(
|
||||
f"{url} is required for provider {provider_type.verbose_name}"
|
||||
f"{url} is required for provider {source_type.verbose_name}"
|
||||
)
|
||||
return attrs
|
||||
|
||||
|
||||
31
authentik/sources/oauth/tests/test_api.py
Normal file
31
authentik/sources/oauth/tests/test_api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
||||
|
||||
class TestOAuthSourceAPI(APITestCase):
|
||||
def setUp(self):
|
||||
self.source = OAuthSource.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider_type="openidconnect",
|
||||
authorization_url="",
|
||||
profile_url="",
|
||||
consumer_key=generate_id(),
|
||||
)
|
||||
self.user = create_test_admin_user()
|
||||
|
||||
def test_patch_no_type(self):
|
||||
self.client.force_login(self.user)
|
||||
res = self.client.patch(
|
||||
reverse("authentik_api:oauthsource-detail", kwargs={"slug": self.source.slug}),
|
||||
{
|
||||
"authorization_url": f"https://{generate_id()}",
|
||||
"profile_url": f"https://{generate_id()}",
|
||||
"access_token_url": f"https://{generate_id()}",
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
@@ -38,6 +38,7 @@ from authentik.stages.authenticator_validate.models import AuthenticatorValidate
|
||||
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
|
||||
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
|
||||
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD_ARGS
|
||||
|
||||
LOGGER = get_logger()
|
||||
if TYPE_CHECKING:
|
||||
@@ -143,7 +144,11 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
|
||||
credentials={"username": user.username},
|
||||
request=stage_view.request,
|
||||
stage=stage_view.executor.current_stage,
|
||||
device_class=DeviceClasses.TOTP.value,
|
||||
context={
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"device_class": DeviceClasses.TOTP.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
raise ValidationError(
|
||||
_("Invalid Token. Please ensure the time on your device is accurate and try again.")
|
||||
@@ -215,9 +220,13 @@ def validate_challenge_webauthn(
|
||||
credentials={"username": user.username},
|
||||
request=stage_view.request,
|
||||
stage=stage_view.executor.current_stage,
|
||||
device=device,
|
||||
device_class=DeviceClasses.WEBAUTHN.value,
|
||||
device_type=device.device_type,
|
||||
context={
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"device": device,
|
||||
"device_class": DeviceClasses.WEBAUTHN.value,
|
||||
"device_type": device.device_type,
|
||||
},
|
||||
},
|
||||
)
|
||||
raise ValidationError("Assertion failed") from exc
|
||||
|
||||
@@ -267,8 +276,12 @@ def validate_challenge_duo(device_pk: int, stage_view: StageView, user: User) ->
|
||||
credentials={"username": user.username},
|
||||
request=stage_view.request,
|
||||
stage=stage_view.executor.current_stage,
|
||||
device_class=DeviceClasses.DUO.value,
|
||||
duo_response=response,
|
||||
context={
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"device_class": DeviceClasses.DUO.value,
|
||||
"duo_response": response,
|
||||
}
|
||||
},
|
||||
)
|
||||
raise ValidationError("Duo denied access", code="denied")
|
||||
return device
|
||||
|
||||
@@ -99,6 +99,7 @@ class IdentificationChallenge(Challenge):
|
||||
password_fields = BooleanField()
|
||||
allow_show_password = BooleanField(default=False)
|
||||
application_pre = CharField(required=False)
|
||||
application_pre_launch = CharField(required=False)
|
||||
flow_designation = ChoiceField(FlowDesignation.choices)
|
||||
captcha_stage = CaptchaChallenge(required=False, allow_null=True)
|
||||
|
||||
@@ -348,9 +349,12 @@ class IdentificationStageView(ChallengeStageView):
|
||||
# If the user has been redirected to us whilst trying to access an
|
||||
# application, PLAN_CONTEXT_APPLICATION is set in the flow plan
|
||||
if PLAN_CONTEXT_APPLICATION in self.executor.plan.context:
|
||||
challenge.initial_data["application_pre"] = self.executor.plan.context.get(
|
||||
app: Application = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_APPLICATION, Application()
|
||||
).name
|
||||
)
|
||||
challenge.initial_data["application_pre"] = app.name
|
||||
if launch_url := app.get_launch_url():
|
||||
challenge.initial_data["application_pre_launch"] = launch_url
|
||||
if (
|
||||
PLAN_CONTEXT_DEVICE in self.executor.plan.context
|
||||
and PLAN_CONTEXT_DEVICE_AUTH_TOKEN in self.executor.plan.context
|
||||
|
||||
@@ -157,7 +157,7 @@ class TaskViewSet(
|
||||
def retry(self, request: Request, pk=None) -> Response:
|
||||
"""Retry task"""
|
||||
task: Task = self.get_object()
|
||||
if task.state not in (TaskState.REJECTED, TaskState.DONE):
|
||||
if task.state != TaskState.REJECTED:
|
||||
return Response(status=400)
|
||||
broker = get_broker()
|
||||
broker.enqueue(Message.decode(task.message))
|
||||
|
||||
178
blueprints/example/flows-recovery-email-mfa-verification.yaml
Normal file
178
blueprints/example/flows-recovery-email-mfa-verification.yaml
Normal file
@@ -0,0 +1,178 @@
|
||||
version: 1
|
||||
metadata:
|
||||
labels:
|
||||
blueprints.goauthentik.io/instantiate: "false"
|
||||
name: Example - Recovery with email and MFA verification
|
||||
entries:
|
||||
- identifiers:
|
||||
slug: default-recovery-flow
|
||||
id: flow
|
||||
model: authentik_flows.flow
|
||||
attrs:
|
||||
name: Default recovery flow
|
||||
title: Reset your password
|
||||
designation: recovery
|
||||
authentication: require_unauthenticated
|
||||
- identifiers:
|
||||
name: default-recovery-field-password
|
||||
id: prompt-field-password
|
||||
model: authentik_stages_prompt.prompt
|
||||
attrs:
|
||||
field_key: password
|
||||
label: Password
|
||||
type: password
|
||||
required: true
|
||||
placeholder: Password
|
||||
order: 0
|
||||
placeholder_expression: false
|
||||
- identifiers:
|
||||
name: default-recovery-field-password-repeat
|
||||
id: prompt-field-password-repeat
|
||||
model: authentik_stages_prompt.prompt
|
||||
attrs:
|
||||
field_key: password_repeat
|
||||
label: Password (repeat)
|
||||
type: password
|
||||
required: true
|
||||
placeholder: Password (repeat)
|
||||
order: 1
|
||||
placeholder_expression: false
|
||||
- identifiers:
|
||||
name: default-recovery-skip-if-restored
|
||||
id: default-recovery-skip-if-restored
|
||||
model: authentik_policies_expression.expressionpolicy
|
||||
attrs:
|
||||
expression: |
|
||||
return bool(request.context.get('is_restored', True))
|
||||
- identifiers:
|
||||
name: default-recovery-email
|
||||
id: default-recovery-email
|
||||
model: authentik_stages_email.emailstage
|
||||
attrs:
|
||||
use_global_settings: true
|
||||
host: localhost
|
||||
port: 25
|
||||
username: ""
|
||||
use_tls: false
|
||||
use_ssl: false
|
||||
timeout: 10
|
||||
from_address: system@authentik.local
|
||||
token_expiry: minutes=30
|
||||
subject: authentik
|
||||
template: email/password_reset.html
|
||||
activate_user_on_success: true
|
||||
recovery_max_attempts: 5
|
||||
recovery_cache_timeout: minutes=5
|
||||
- identifiers:
|
||||
name: default-recovery-mfa
|
||||
id: default-recovery-mfa
|
||||
model: authentik_stages_authenticator_validate.authenticatorvalidatestage
|
||||
- identifiers:
|
||||
name: default-recovery-user-write
|
||||
id: default-recovery-user-write
|
||||
model: authentik_stages_user_write.userwritestage
|
||||
attrs:
|
||||
user_creation_mode: never_create
|
||||
- identifiers:
|
||||
name: default-recovery-identification
|
||||
id: default-recovery-identification
|
||||
model: authentik_stages_identification.identificationstage
|
||||
attrs:
|
||||
user_fields:
|
||||
- email
|
||||
- username
|
||||
- identifiers:
|
||||
name: default-recovery-user-login
|
||||
id: default-recovery-user-login
|
||||
model: authentik_stages_user_login.userloginstage
|
||||
- identifiers:
|
||||
name: Change your password
|
||||
id: stages-prompt-password
|
||||
model: authentik_stages_prompt.promptstage
|
||||
attrs:
|
||||
fields:
|
||||
- !KeyOf prompt-field-password
|
||||
- !KeyOf prompt-field-password-repeat
|
||||
validation_policies: []
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf default-recovery-identification
|
||||
order: 10
|
||||
model: authentik_flows.flowstagebinding
|
||||
id: flow-binding-identification
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: true
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf default-recovery-email
|
||||
order: 20
|
||||
model: authentik_flows.flowstagebinding
|
||||
id: flow-binding-email
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: true
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf default-recovery-mfa
|
||||
order: 21
|
||||
model: authentik_flows.flowstagebinding
|
||||
id: flow-binding-mfa
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: true
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf stages-prompt-password
|
||||
order: 30
|
||||
model: authentik_flows.flowstagebinding
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: false
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf default-recovery-user-write
|
||||
order: 40
|
||||
model: authentik_flows.flowstagebinding
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: false
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
target: !KeyOf flow
|
||||
stage: !KeyOf default-recovery-user-login
|
||||
order: 100
|
||||
model: authentik_flows.flowstagebinding
|
||||
attrs:
|
||||
evaluate_on_plan: true
|
||||
re_evaluate_policies: false
|
||||
policy_engine_mode: any
|
||||
invalid_response_action: retry
|
||||
- identifiers:
|
||||
policy: !KeyOf default-recovery-skip-if-restored
|
||||
target: !KeyOf flow-binding-identification
|
||||
order: 0
|
||||
model: authentik_policies.policybinding
|
||||
attrs:
|
||||
negate: false
|
||||
enabled: true
|
||||
timeout: 30
|
||||
- identifiers:
|
||||
policy: !KeyOf default-recovery-skip-if-restored
|
||||
target: !KeyOf flow-binding-email
|
||||
order: 0
|
||||
state: absent
|
||||
model: authentik_policies.policybinding
|
||||
attrs:
|
||||
negate: false
|
||||
enabled: true
|
||||
timeout: 30
|
||||
@@ -2,7 +2,7 @@
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://goauthentik.io/blueprints/schema.json",
|
||||
"type": "object",
|
||||
"title": "authentik 2026.2.0-rc5 Blueprint schema",
|
||||
"title": "authentik 2026.2.3-rc1 Blueprint schema",
|
||||
"required": [
|
||||
"version",
|
||||
"entries"
|
||||
|
||||
@@ -16,6 +16,7 @@ type Config struct {
|
||||
Listen ListenConfig `yaml:"listen" env:", prefix=AUTHENTIK_LISTEN__"`
|
||||
Web WebConfig `yaml:"web" env:", prefix=AUTHENTIK_WEB__"`
|
||||
Log LogConfig `yaml:"log" env:", prefix=AUTHENTIK_LOG__"`
|
||||
LDAP LDAPConfig `yaml:"ldap" env:", prefix=AUTHENTIK_LDAP__"`
|
||||
|
||||
// Outpost specific config
|
||||
// These are only relevant for proxy/ldap outposts, and cannot be set via YAML
|
||||
@@ -104,9 +105,17 @@ type OutpostConfig struct {
|
||||
}
|
||||
|
||||
type WebConfig struct {
|
||||
Path string `yaml:"path" env:"PATH, overwrite"`
|
||||
Path string `yaml:"path" env:"PATH, overwrite"`
|
||||
TimeoutHttpReadHeader string `yaml:"timeout_http_read_header" env:"TIMEOUT_HTTP_READ_HEADER, overwrite"`
|
||||
TimeoutHttpRead string `yaml:"timeout_http_read" env:"TIMEOUT_HTTP_READ, overwrite"`
|
||||
TimeoutHttpWrite string `yaml:"timeout_http_write" env:"TIMEOUT_HTTP_WRITE, overwrite"`
|
||||
TimeoutHttpIdle string `yaml:"timeout_http_idle" env:"TIMEOUT_HTTP_IDLE, overwrite"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
HttpHeaders []string `yaml:"http_headers" env:"HTTP_HEADERS, overwrite"`
|
||||
}
|
||||
|
||||
type LDAPConfig struct {
|
||||
PageSize int `yaml:"page_size" env:"PAGE_SIZE, overwrite"`
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
2026.2.0-rc5
|
||||
2026.2.3-rc1
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/group"
|
||||
@@ -83,10 +84,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
|
||||
entries := make([]*ldap.Entry, 0)
|
||||
|
||||
// Create a custom client to set additional headers
|
||||
c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
|
||||
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
|
||||
|
||||
scope := req.Scope
|
||||
needUsers, needGroups := ds.si.GetNeededObjects(scope, req.BaseDN, req.FilterObjectClass)
|
||||
|
||||
@@ -113,7 +110,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
errs.Go(func() error {
|
||||
if flags.CanSearch {
|
||||
uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
|
||||
searchReq, skip := utils.ParseFilterForUser(c.CoreAPI.CoreUsersList(uapisp.Context()).IncludeGroups(true), parsedFilter, false)
|
||||
searchReq, skip := utils.ParseFilterForUser(ds.si.GetAPIClient().CoreAPI.CoreUsersList(uapisp.Context()).IncludeGroups(true), parsedFilter, false)
|
||||
|
||||
if skip {
|
||||
req.Log().Trace("Skip backend request")
|
||||
@@ -121,7 +118,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
}
|
||||
|
||||
u, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
PageSize: config.Get().LDAP.PageSize,
|
||||
Logger: ds.log,
|
||||
})
|
||||
uapisp.Finish()
|
||||
@@ -132,7 +129,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
} else {
|
||||
if flags.UserInfo == nil {
|
||||
uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
|
||||
u, _, err := c.CoreAPI.CoreUsersRetrieve(uapisp.Context(), flags.UserPk).Execute()
|
||||
u, _, err := ds.si.GetAPIClient().CoreAPI.CoreUsersRetrieve(uapisp.Context(), flags.UserPk).Execute()
|
||||
uapisp.Finish()
|
||||
|
||||
if err != nil {
|
||||
@@ -155,7 +152,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
if needGroups {
|
||||
errs.Go(func() error {
|
||||
gapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_group")
|
||||
searchReq, skip := utils.ParseFilterForGroup(c.CoreAPI.CoreGroupsList(gapisp.Context()).IncludeUsers(true).IncludeChildren(true).IncludeParents(true), parsedFilter, false)
|
||||
searchReq, skip := utils.ParseFilterForGroup(ds.si.GetAPIClient().CoreAPI.CoreGroupsList(gapisp.Context()).IncludeUsers(true).IncludeChildren(true).IncludeParents(true), parsedFilter, false)
|
||||
if skip {
|
||||
req.Log().Trace("Skip backend request")
|
||||
return nil
|
||||
@@ -167,7 +164,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
}
|
||||
|
||||
g, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
PageSize: config.Get().LDAP.PageSize,
|
||||
Logger: ds.log,
|
||||
})
|
||||
gapisp.Finish()
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
@@ -53,12 +54,12 @@ func NewMemorySearcher(si server.LDAPServerInstance, existing search.Searcher) *
|
||||
func (ms *MemorySearcher) fetch() {
|
||||
// Error is not handled here, we get an empty/truncated list and the error is logged
|
||||
users, _ := ak.Paginator(ms.si.GetAPIClient().CoreAPI.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
PageSize: config.Get().LDAP.PageSize,
|
||||
Logger: ms.log,
|
||||
})
|
||||
ms.users = users
|
||||
groups, _ := ak.Paginator(ms.si.GetAPIClient().CoreAPI.CoreGroupsList(context.TODO()).IncludeUsers(true).IncludeChildren(true).IncludeParents(true), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
PageSize: config.Get().LDAP.PageSize,
|
||||
Logger: ms.log,
|
||||
})
|
||||
ms.groups = groups
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/proxyv2/constants"
|
||||
"goauthentik.io/internal/outpost/proxyv2/hs256"
|
||||
"goauthentik.io/internal/outpost/proxyv2/metrics"
|
||||
"goauthentik.io/internal/outpost/proxyv2/templates"
|
||||
@@ -294,22 +293,16 @@ func (a *Application) Stop() {
|
||||
|
||||
func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
|
||||
redirect := a.endpoint.EndSessionEndpoint
|
||||
s, err := a.sessions.Get(r, a.SessionName())
|
||||
if err != nil {
|
||||
cc := a.getClaimsFromSession(rw, r)
|
||||
if cc == nil {
|
||||
a.redirectToStart(rw, r)
|
||||
return
|
||||
}
|
||||
c, exists := s.Values[constants.SessionClaims]
|
||||
if c == nil && !exists {
|
||||
a.redirectToStart(rw, r)
|
||||
return
|
||||
}
|
||||
cc := c.(types.Claims)
|
||||
uv := url.Values{
|
||||
"id_token_hint": []string{cc.RawToken},
|
||||
}
|
||||
redirect += "?" + uv.Encode()
|
||||
err = a.Logout(r.Context(), func(c types.Claims) bool {
|
||||
err := a.Logout(r.Context(), func(c types.Claims) bool {
|
||||
return c.Sub == cc.Sub
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -187,10 +187,7 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
if connConfig.RuntimeParams == nil {
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
}
|
||||
|
||||
if cfg.DefaultSchema != "" {
|
||||
connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema
|
||||
}
|
||||
effectiveSearchPath := cfg.DefaultSchema
|
||||
|
||||
// Parse and apply connection options if specified
|
||||
if cfg.ConnOptions != "" {
|
||||
@@ -198,12 +195,39 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse connection options: %w", err)
|
||||
}
|
||||
// search_path from ConnOptions is not supported here; Django controls schema selection.
|
||||
// Always remove it so it cannot end up in startup RuntimeParams via applyConnOptions.
|
||||
delete(connOpts, "search_path")
|
||||
|
||||
if err := applyConnOptions(connConfig, connOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply connection options: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// search_path may already be present via pgx/libpq inherited defaults (e.g. service files).
|
||||
// Always remove it from startup RuntimeParams; apply it via AfterConnect instead.
|
||||
if inheritedSearchPath, hasInheritedSearchPath := connConfig.RuntimeParams["search_path"]; hasInheritedSearchPath {
|
||||
if effectiveSearchPath == "" {
|
||||
effectiveSearchPath = inheritedSearchPath
|
||||
}
|
||||
delete(connConfig.RuntimeParams, "search_path")
|
||||
}
|
||||
|
||||
// Set search_path after connection startup to avoid startup-parameter issues with PgBouncer.
|
||||
if effectiveSearchPath != "" {
|
||||
connConfig.AfterConnect = func(ctx context.Context, pgConn *pgconn.PgConn) error {
|
||||
result := pgConn.ExecParams(
|
||||
ctx,
|
||||
"select pg_catalog.set_config('search_path', $1, false)",
|
||||
[][]byte{[]byte(effectiveSearchPath)},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
).Read()
|
||||
return result.Err
|
||||
}
|
||||
}
|
||||
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user