Mirror of https://github.com/goauthentik/authentik, synced 2026-05-05 22:52:42 +02:00
Compare commits: modal-revi… ... version/20… (138 commits)
| SHA1 |
|---|
| 319008dec8 |
| 8beb2fac18 |
| ac7b28d0b0 |
| 073acf92c2 |
| ad107c19af |
| d285fcd8a7 |
| 84066cab48 |
| e623d93ff5 |
| 1d0628dfbe |
| 996645105c |
| 63d7ca6ef0 |
| 5b24f4ad80 |
| ed2e6cfb9c |
| a1431ea48e |
| b30e77b363 |
| 2f50cdd9fe |
| 494bdcaa09 |
| e36ce1789e |
| 5a72ed83e0 |
| f72d257e43 |
| cbedb16cc4 |
| 6fc1b5ce90 |
| 57b0fa48c1 |
| 84a344ed87 |
| f864cb56ab |
| 692735f9e1 |
| e24fb300b1 |
| f0e90d6873 |
| 0cf45835a0 |
| 69d35c1d26 |
| ac803b210d |
| c9728b4607 |
| 6e45584563 |
| 59a2e84b35 |
| 6025dbb9c9 |
| d07bcd5025 |
| e80655d285 |
| e0d3d4d38c |
| 62112404ee |
| 1c9e12fcd9 |
| 42c6c257ec |
| 41bd9d7913 |
| 2c84935732 |
| 819c13a9bc |
| 0d8f366af8 |
| 093e60c753 |
| af646f32d2 |
| de4afc7322 |
| bc1983106f |
| 8c2c1474f1 |
| 0dccbd4193 |
| 6a70894e01 |
| 2f5eb9b2e4 |
| 12aedb3a9e |
| 303dc93514 |
| fbb217db57 |
| 4de253653f |
| 4154c06831 |
| 4750ed5e2a |
| 361017127d |
| 0ca5a54307 |
| ef1aad5dbb |
| 29d880920e |
| fc6f8374e6 |
| a8668bbac4 |
| d686932166 |
| feceb220b1 |
| 937df6e07f |
| 48e6b968a6 |
| cd89c45e75 |
| e53995e2c1 |
| 33d5f11f0e |
| 565e16eca7 |
| 9a0164b722 |
| 8af491630b |
| 8e25e7a213 |
| 4d183657da |
| be89b6052d |
| ad5d2bb611 |
| 8d30fb3d25 |
| cea3fbfa9b |
| 151d889ff4 |
| 58ca3ecbd5 |
| 1a6c7082a3 |
| 1dc60276f9 |
| de045c6d7b |
| 850728e9bb |
| 84a605a4ba |
| 1780bb0cf0 |
| cd75fe235d |
| e6e62e9de1 |
| ac7a4f8a22 |
| 0290ed3342 |
| e367525794 |
| 93c319baee |
| 1d02ee7d74 |
| 93439b5742 |
| 6682a6664e |
| 0b5bac74e9 |
| 062823f1b2 |
| a17fe58971 |
| 422ea893b1 |
| 15c9f93851 |
| e2202d498b |
| 9ea9a86ad3 |
| 4bac1edd61 |
| 24726be3c9 |
| 411f06756f |
| 4bdcab48c3 |
| 00dbd377a7 |
| a01c0575db |
| 6e51d044bb |
| 6d1b168dc4 |
| 43675c2b22 |
| 8645273eaf |
| eb6f4712fe |
| 7b9505242e |
| 3dda20ebc7 |
| dfd2bc5c3c |
| 06a270913c |
| 430507fc72 |
| 847af7f9ea |
| 8f1cb636e8 |
| e61c876002 |
| 33c0d3df0a |
| 3a03e1ebfd |
| 1e41b77761 |
| 6c1662f99f |
| bb5bc5c8da |
| 30670c9070 |
| fdbf9ffedc |
| 2ec433d724 |
| 55297b9e6a |
| f9dda6582c |
| 3394c17bfd |
| a37d101b10 |
| 4774b4db87 |
| fdb52c9394 |
23 .github/actions/cherry-pick/action.yml (vendored)

@@ -115,20 +115,13 @@ runs:
shell: bash
env:
GITHUB_TOKEN: ${{ inputs.token }}
PR_NUMBER: ${{ steps.should_run.outputs.pr_number }}
REASON: ${{ steps.should_run.outputs.reason }}
run: |
set -e -o pipefail
PR_NUMBER="${{ steps.should_run.outputs.pr_number }}"

# Get PR details
PR_DATA=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER)
PR_TITLE=$(echo "$PR_DATA" | jq -r '.title')
PR_AUTHOR=$(echo "$PR_DATA" | jq -r '.user.login')

echo "pr_title=$PR_TITLE" >> $GITHUB_OUTPUT
echo "pr_author=$PR_AUTHOR" >> $GITHUB_OUTPUT

# Determine which labels to process
if [ "${{ steps.should_run.outputs.reason }}" = "label_added_to_merged_pr" ]; then
if [ "${REASON}" = "label_added_to_merged_pr" ]; then
# Only process the specific label that was just added
if [ "${{ github.event_name }}" = "issues" ]; then
LABEL_NAME="${{ github.event.label.name }}"

@@ -152,13 +145,13 @@ runs:
shell: bash
env:
GITHUB_TOKEN: ${{ inputs.token }}
PR_NUMBER: '${{ steps.should_run.outputs.pr_number }}'
COMMIT_SHA: '${{ steps.should_run.outputs.merge_commit_sha }}'
PR_TITLE: ${{ github.event.pull_request.title }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
LABELS: '${{ steps.pr_details.outputs.labels }}'
run: |
set -e -o pipefail
PR_NUMBER='${{ steps.should_run.outputs.pr_number }}'
COMMIT_SHA='${{ steps.should_run.outputs.merge_commit_sha }}'
PR_TITLE='${{ steps.pr_details.outputs.pr_title }}'
PR_AUTHOR='${{ steps.pr_details.outputs.pr_author }}'
LABELS='${{ steps.pr_details.outputs.labels }}'

echo "Processing PR #$PR_NUMBER (reason: ${{ steps.should_run.outputs.reason }})"
echo "Found backport labels: $LABELS"
2 .github/actions/setup/action.yml (vendored)

@@ -58,7 +58,7 @@ runs:
run: |
export PSQL_TAG=${{ inputs.postgresql_version }}
docker compose -f .github/actions/setup/compose.yml up -d
cd web && npm i
cd web && npm ci
- name: Generate config
if: ${{ contains(inputs.dependencies, 'python') }}
shell: uv run python {0}
2 .github/workflows/api-ts-publish.yml (vendored)

@@ -21,7 +21,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
with:
token: ${{ steps.generate_token.outputs.token }}
2 .github/workflows/gen-image-compress.yml (vendored)

@@ -32,7 +32,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
with:
token: ${{ steps.generate_token.outputs.token }}

@@ -19,7 +19,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
with:
token: ${{ steps.generate_token.outputs.token }}
2 .github/workflows/gh-cherry-pick.yml (vendored)

@@ -14,7 +14,7 @@ jobs:
if: ${{ env.GH_APP_ID != '' }}
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
env:
GH_APP_ID: ${{ secrets.GH_APP_ID }}
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
2 .github/workflows/gh-ghcr-retention.yml (vendored)

@@ -19,7 +19,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- name: Delete 'dev' containers older than a week
uses: snok/container-retention-policy@3b0972b2276b171b212f8c4efbca59ebba26eceb # v3.0.1
with:
4 .github/workflows/release-branch-off.yml (vendored)

@@ -32,7 +32,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- name: Checkout main
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
with:

@@ -60,7 +60,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- name: Checkout main
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
with:
21 .github/workflows/release-publish.yml (vendored)

@@ -160,10 +160,17 @@ jobs:
node-version-file: web/package.json
cache: "npm"
cache-dependency-path: web/package-lock.json
- name: Build web
- name: Install web dependencies
working-directory: web/
run: |
npm ci
- name: Generate API Clients
run: |
make gen-client-ts
make gen-client-go
- name: Build web
working-directory: web/
run: |
npm run build-proxy
- name: Build outpost
run: |

@@ -210,12 +217,12 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
- name: Run test suite in final docker images
run: |
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> .env
docker compose pull -q
docker compose up --no-start
docker compose start postgresql
docker compose run -u root server test-all
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> lifecycle/container/.env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> lifecycle/container/.env
docker compose -f lifecycle/container/compose.yml pull -q
docker compose -f lifecycle/container/compose.yml up --no-start
docker compose -f lifecycle/container/compose.yml start postgresql
docker compose -f lifecycle/container/compose.yml run -u root server test-all
sentry-release:
needs:
- build-server
15 .github/workflows/release-tag.yml (vendored)

@@ -70,7 +70,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- id: get-user-id
name: Get GitHub app user ID
run: echo "user-id=$(gh api "/users/${{ steps.app-token.outputs.app-slug }}[bot]" --jq .id)" >> "$GITHUB_OUTPUT"

@@ -91,6 +91,7 @@ jobs:
# ID from https://api.github.com/users/authentik-automation[bot]
git config --global user.name '${{ steps.app-token.outputs.app-slug }}[bot]'
git config --global user.email '${{ steps.get-user-id.outputs.user-id }}+${{ steps.app-token.outputs.app-slug }}[bot]@users.noreply.github.com'
git pull
git commit -a -m "release: ${{ inputs.version }}" --allow-empty
git tag "version/${{ inputs.version }}" HEAD -m "version/${{ inputs.version }}"
git push --follow-tags

@@ -117,7 +118,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
repositories: helm
- id: get-user-id
name: Get GitHub app user ID

@@ -159,7 +160,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
repositories: version
- id: get-user-id
name: Get GitHub app user ID

@@ -174,21 +175,25 @@ jobs:
if: "${{ inputs.release_reason == 'feature' }}"
run: |
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}"
reason="${{ inputs.release_reason }}"
jq \
--arg version "${{ inputs.version }}" \
--arg changelog "See ${changelog_url}" \
--arg changelog_url "${changelog_url}" \
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
--arg reason "${reason}" \
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
mv version.new.json version.json
- name: Bump version
if: "${{ inputs.release_reason != 'feature' }}"
run: |
changelog_url="https://docs.goauthentik.io/docs/releases/${{ needs.check-inputs.outputs.major_version }}#fixed-in-$(echo -n ${{ inputs.version}} | sed 's/\.//g')"
reason="${{ inputs.release_reason }}"
jq \
--arg version "${{ inputs.version }}" \
--arg changelog "See ${changelog_url}" \
--arg changelog_url "${changelog_url}" \
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url' version.json > version.new.json
--arg reason "${reason}" \
'.stable.version = $version | .stable.changelog = $changelog | .stable.changelog_url = $changelog_url | .stable.reason = $reason' version.json > version.new.json
mv version.new.json version.json
- name: Create pull request
uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v7
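For orientation, the jq invocations above rewrite the `.stable` block of `version.json` and, with this change, additionally record the release reason. A minimal Python sketch of the same transformation follows; the file name and the example values (version, changelog URL, "bugfix") are illustrative, and the workflow actually passes `inputs.release_reason`.

```python
import json

def bump_stable(path: str, version: str, changelog_url: str, reason: str) -> None:
    """Sketch of what the jq filter in release-tag.yml does to version.json."""
    with open(path) as fh:
        data = json.load(fh)
    stable = data["stable"]
    stable["version"] = version
    stable["changelog"] = f"See {changelog_url}"
    stable["changelog_url"] = changelog_url
    stable["reason"] = reason  # the field this change adds
    with open(path, "w") as fh:
        json.dump(data, fh, indent=4)
        fh.write("\n")

bump_stable(
    "version.json",
    "2026.2.2",
    "https://docs.goauthentik.io/docs/releases/2026.2",
    "bugfix",  # assumed value; the real workflow uses inputs.release_reason
)
```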
2 .github/workflows/repo-stale.yml (vendored)

@@ -18,7 +18,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10
with:
repo-token: ${{ steps.generate_token.outputs.token }}

@@ -24,7 +24,7 @@ jobs:
uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2
with:
app-id: ${{ secrets.GH_APP_ID }}
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
private-key: ${{ secrets.GH_APP_PRIV_KEY }}
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v5
if: ${{ github.event_name != 'pull_request' }}
with:
8 Makefile

@@ -148,11 +148,11 @@ bump: ## Bump authentik version. Usage: make bump version=20xx.xx.xx
ifndef version
$(error Usage: make bump version=20xx.xx.xx )
endif
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' pyproject.toml
$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' authentik/__init__.py
$(eval current_version := $(shell cat ${PWD}/internal/constants/VERSION))
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml
$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' ${PWD}/authentik/__init__.py
$(MAKE) gen-build gen-compose aws-cfn
npm version --no-git-tag-version --allow-same-version $(version)
cd ${PWD}/web && npm version --no-git-tag-version --allow-same-version $(version)
$(SED_INPLACE) "s/\"${current_version}\"/\"$(version)\"/" ${PWD}/package.json ${PWD}/package-lock.json ${PWD}/web/package.json ${PWD}/web/package-lock.json
echo -n $(version) > ${PWD}/internal/constants/VERSION

#########################
@@ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni

| Version | Supported |
| ---------- | ---------- |
| 2025.10.x | ✅ |
| 2025.12.x | ✅ |
| 2026.2.x | ✅ |

## Reporting a Vulnerability

@@ -3,7 +3,7 @@
from functools import lru_cache
from os import environ

VERSION = "2026.2.0-rc1"
VERSION = "2026.2.2-rc1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
@@ -1,6 +1,8 @@
"""Schema generation tests"""

from pathlib import Path
from tempfile import gettempdir
from uuid import uuid4

from django.core.management import call_command
from django.urls import reverse

@@ -29,15 +31,14 @@ class TestSchemaGeneration(APITestCase):

def test_build_schema(self):
"""Test schema build command"""
blueprint_file = Path("blueprints/schema.json")
api_file = Path("schema.yml")
blueprint_file.unlink()
api_file.unlink()
tmp = Path(gettempdir())
blueprint_file = tmp / f"{str(uuid4())}.json"
api_file = tmp / f"{str(uuid4())}.yml"
with (
CONFIG.patch("debug", True),
CONFIG.patch("tenants.enabled", True),
CONFIG.patch("outposts.disable_embedded_outpost", True),
):
call_command("build_schema")
call_command("build_schema", blueprint_file=blueprint_file, api_file=api_file)
self.assertTrue(blueprint_file.exists())
self.assertTrue(api_file.exists())
@@ -47,7 +47,12 @@ class ApplicationSerializer(ModelSerializer):
"""Application Serializer"""

launch_url = SerializerMethodField()
provider_obj = ProviderSerializer(source="get_provider", required=False, read_only=True)
provider_obj = ProviderSerializer(
source="get_provider",
required=False,
read_only=True,
allow_null=True,
)
backchannel_providers_obj = ProviderSerializer(
source="backchannel_providers", required=False, read_only=True, many=True
)
@@ -1,5 +1,7 @@
"""authentik core models"""

import re
import traceback
from datetime import datetime, timedelta
from enum import StrEnum
from hashlib import sha256

@@ -15,7 +17,6 @@ from django.contrib.sessions.base_session import AbstractBaseSession
from django.core.validators import validate_slug
from django.db import models
from django.db.models import Q, QuerySet, options
from django.db.models.constants import LOOKUP_SEP
from django.http import HttpRequest
from django.utils.functional import cached_property
from django.utils.timezone import now

@@ -43,6 +44,7 @@ from authentik.lib.models import (
DomainlessFormattedURLValidator,
SerializerModel,
)
from authentik.lib.utils.inheritance import get_deepest_child
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.models import PolicyBindingModel
from authentik.rbac.models import Role

@@ -528,23 +530,35 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
"default: in 30 days). See authentik logs for every will invocation of this "
"deprecation."
)
stacktrace = traceback.format_stack()
# The last line is this function, the next-to-last line is its caller
cause = stacktrace[-2] if len(stacktrace) > 1 else "Unknown, see stacktrace in logs"
if search := re.search(r'"(.*?)"', cause):
cause = f"Property mapping or Expression policy named {search.group(1)}"

LOGGER.warning(
"deprecation used",
message=message_logger,
deprecation=deprecation,
replacement=replacement,
cause=cause,
stacktrace=stacktrace,
)
if not Event.filter_not_expired(
action=EventAction.CONFIGURATION_WARNING, context__deprecation=deprecation
action=EventAction.CONFIGURATION_WARNING,
context__deprecation=deprecation,
context__cause=cause,
).exists():
event = Event.new(
EventAction.CONFIGURATION_WARNING,
deprecation=deprecation,
replacement=replacement,
message=message_event,
cause=cause,
)
event.expires = datetime.now() + timedelta(days=30)
event.save()

return self.groups

def set_password(self, raw_password, signal=True, sender=None, request=None):
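The hunk above attributes a deprecation warning to its caller by inspecting the Python call stack. A self-contained sketch of that stack-inspection pattern is below; the helper name and the message format are illustrative, not authentik's exact wording.

```python
import re
import traceback

def caller_description() -> str:
    # format_stack() lists frames oldest-to-newest; the last entry is this
    # function itself, so the next-to-last entry is the direct caller.
    stack = traceback.format_stack()
    cause = stack[-2] if len(stack) > 1 else "Unknown, see stacktrace in logs"
    # If the caller's frame text carries a quoted name (as dynamically compiled
    # expression snippets can), prefer that name over the raw frame text.
    if match := re.search(r'"(.*?)"', cause):
        cause = f"Named expression: {match.group(1)}"
    return cause

print(caller_description())
```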
@@ -789,25 +803,7 @@ class Application(SerializerModel, PolicyBindingModel):
"""Get casted provider instance. Needs Application queryset with_provider"""
if not self.provider:
return None

candidates = []
base_class = Provider
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
parent = self.provider
for level in subclass.split(LOOKUP_SEP):
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
continue
idx = subclass.count(LOOKUP_SEP)
if type(parent) is not base_class:
idx += 1
candidates.insert(idx, parent)
if not candidates:
return None
return candidates[-1]
return get_deepest_child(self.provider)

def backchannel_provider_for[T: Provider](self, provider_type: type[T], **kwargs) -> T | None:
"""Get Backchannel provider for a specific type"""
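The refactor above replaces the hand-rolled subclass walk with a shared `get_deepest_child` helper whose implementation is not part of this diff. As a rough, hypothetical sketch of what such a helper could do, the snippet below walks Django multi-table-inheritance parent links downwards on an already-fetched instance until no more-specific child row exists; it is not the actual `authentik.lib.utils.inheritance` code.

```python
from django.db.models import Model
from django.db.models.fields.reverse_related import OneToOneRel

def deepest_child(instance: Model) -> Model:
    """Return the most derived multi-table-inheritance instance reachable
    from the given parent instance (hypothetical sketch)."""
    current = instance
    while True:
        child = None
        for rel in current._meta.related_objects:
            # The automatic MTI link is a reverse OneToOne whose field is the
            # child table's primary key.
            if isinstance(rel, OneToOneRel) and rel.field.primary_key:
                try:
                    child = getattr(current, rel.get_accessor_name())
                    break
                except rel.related_model.DoesNotExist:
                    continue
        if child is None:
            return current
        current = child
```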
@@ -1119,7 +1115,11 @@ class ExpiringModel(models.Model):
default the object is deleted. This is less efficient compared
to bulk deleting objects, but classes like Token() need to change
values instead of being deleted."""
return self.delete(*args, **kwargs)
try:
return self.delete(*args, **kwargs)
except self.DoesNotExist:
# Object has already been deleted, so this should be fine
return None

@classmethod
def filter_not_expired(cls, **kwargs) -> QuerySet[Self]:
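The diff itself does not show which part of the delete path raises `DoesNotExist`; one plausible scenario is a subclass that re-reads the row before removing it while a concurrent expiry task has already deleted it. The sketch below uses a hypothetical model to illustrate how the new guard absorbs that race.

```python
from django.db import models

class DemoToken(models.Model):
    """Hypothetical stand-in for an ExpiringModel subclass whose delete path
    re-reads the row before removing it."""
    key = models.TextField()

    class Meta:
        app_label = "demo"

    def delete(self, *args, **kwargs):
        self.refresh_from_db()  # raises DemoToken.DoesNotExist if the row is gone
        return super().delete(*args, **kwargs)

def expire_action(obj: DemoToken):
    try:
        return obj.delete()
    except DemoToken.DoesNotExist:
        # A concurrent expiry task already removed the object; treat as done.
        return None
```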
@@ -24,7 +24,8 @@ from authentik.root.ws.consumer import build_device_group

# Arguments: user: User, password: str
password_changed = Signal()
# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
# Arguments: credentials: dict[str, any], request: HttpRequest,
# stage: Stage, context: dict[str, any]
login_failed = Signal()

LOGGER = get_logger()

@@ -44,19 +44,24 @@
{% endblock %}
</div>
</main>
<footer aria-label="Site footer" class="pf-c-login__footer pf-m-dark">
<ul class="pf-c-list pf-m-inline">
{% for link in footer_links %}
<li>
<a href="{{ link.href }}">{{ link.name }}</a>
</li>
{% endfor %}
<li>
<span>
{% trans 'Powered by authentik' %}
</span>
</li>
</ul>
<footer
name="site-footer"
aria-label="{% trans 'Site footer' %}"
class="pf-c-login__footer pf-m-dark">
<div name="flow-links" aria-label="{% trans 'Flow links' %}">
<ul class="pf-c-list pf-m-inline" part="list">
{% for link in footer_links %}
<li part="list-item">
<a part="list-item-link" href="{{ link.href }}">{{ link.name }}</a>
</li>
{% endfor %}
<li part="list-item">
<span>
{% trans 'Powered by authentik' %}
</span>
</li>
</ul>
</div>
</footer>
</div>
</div>
@@ -78,7 +78,7 @@ def generate_key_id_legacy(key_data: str) -> str:
"""Generate Key ID using MD5 (legacy format for backwards compatibility)."""
if not key_data:
return ""
return md5(key_data.encode("utf-8")).hexdigest() # nosec
return md5(key_data.encode("utf-8"), usedforsecurity=False).hexdigest() # nosec

class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
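For context, `usedforsecurity=False` is a standard `hashlib` keyword (Python 3.9+) that marks a digest as a non-cryptographic fingerprint, which keeps MD5 usable on FIPS-restricted OpenSSL builds. A minimal standalone example with illustrative input data:

```python
from hashlib import md5

# Marking the hash as non-security keeps it working under FIPS policies that
# block security uses of MD5; the key material here is illustrative.
key_id = md5(b"example key material", usedforsecurity=False).hexdigest()
print(key_id)
```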
@@ -1,8 +1,11 @@
from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import ValidationError
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.used_by import UsedByMixin
from authentik.endpoints.api.connectors import ConnectorSerializer
from authentik.endpoints.models import EndpointStage
from authentik.endpoints.controller import Capabilities
from authentik.endpoints.models import Connector, EndpointStage
from authentik.flows.api.stages import StageSerializer

@@ -11,6 +14,13 @@ class EndpointStageSerializer(StageSerializer):

connector_obj = ConnectorSerializer(source="connector", read_only=True)

def validate_connector(self, connector: Connector) -> Connector:
conn: Connector = Connector.objects.get_subclass(pk=connector.pk)
controller = conn.controller(conn)
if Capabilities.STAGE_ENDPOINTS not in controller.capabilities():
raise ValidationError(_("Selected connector is not compatible with this stage."))
return connector

class Meta:
model = EndpointStage
fields = StageSerializer.Meta.fields + [

@@ -18,7 +18,10 @@ from authentik.rbac.decorators import permission_required
class EnrollmentTokenSerializer(ModelSerializer):

device_group_obj = DeviceAccessGroupSerializer(
source="device_group", read_only=True, required=False
source="device_group",
read_only=True,
required=False,
allow_null=True,
)

def __init__(self, *args, **kwargs) -> None:
@@ -37,6 +37,8 @@ class AgentEnrollmentAuth(BaseAuthentication):
token = EnrollmentToken.filter_not_expired(key=key).first()
if not token:
raise PermissionDenied()
if not token.connector.enabled:
raise PermissionDenied()
CTX_AUTH_VIA.set("endpoint_token_enrollment")
return (DeviceUser(), token)

@@ -51,6 +53,8 @@ class AgentAuth(BaseAuthentication):
device_token = DeviceToken.filter_not_expired(key=key).first()
if not device_token:
raise PermissionDenied()
if not device_token.device.connector.enabled:
raise PermissionDenied()
if device_token.device.device.is_expired:
raise PermissionDenied()
CTX_AUTH_VIA.set("endpoint_token")

@@ -8,7 +8,7 @@ from rest_framework.fields import CharField

from authentik.core.api.utils import PassiveSerializer
from authentik.endpoints.connectors.agent.models import AgentConnector, EnrollmentToken
from authentik.endpoints.controller import BaseController
from authentik.endpoints.controller import BaseController, Capabilities
from authentik.endpoints.facts import OSFamily

@@ -48,8 +48,8 @@ class AgentConnectorController(BaseController[AgentConnector]):
def vendor_identifier() -> str:
return "goauthentik.io/platform"

def supported_enrollment_methods(self):
return []
def capabilities(self) -> list[Capabilities]:
return [Capabilities.STAGE_ENDPOINTS]

def generate_mdm_config(
self, target_platform: OSFamily, request: HttpRequest, token: EnrollmentToken

@@ -58,6 +58,16 @@ class TestAgentAPI(APITestCase):
)
self.assertEqual(response.status_code, 200)

def test_enroll_disabled(self):
self.connector.enabled = False
self.connector.save()
response = self.client.post(
reverse("authentik_api:agentconnector-enroll"),
data={"device_serial": generate_id(), "device_name": "bar"},
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
)
self.assertEqual(response.status_code, 403)

def test_enroll_token_delete(self):
response = self.client.post(
reverse("authentik_api:agentconnector-enroll"),

@@ -104,6 +114,16 @@ class TestAgentAPI(APITestCase):
)
self.assertEqual(response.status_code, 200)

@reconcile_app("authentik_crypto")
def test_config_disabled(self):
self.connector.enabled = False
self.connector.save()
response = self.client.get(
reverse("authentik_api:agentconnector-agent-config"),
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
)
self.assertEqual(response.status_code, 403)

def test_check_in(self):
response = self.client.post(
reverse("authentik_api:agentconnector-check-in"),

@@ -112,6 +132,16 @@ class TestAgentAPI(APITestCase):
)
self.assertEqual(response.status_code, 204)

def test_check_in_disabled(self):
self.connector.enabled = False
self.connector.save()
response = self.client.post(
reverse("authentik_api:agentconnector-check-in"),
data=CHECK_IN_DATA_VALID,
HTTP_AUTHORIZATION=f"Bearer+agent {self.device_token.key}",
)
self.assertEqual(response.status_code, 403)

def test_check_in_token_expired(self):
self.device_token.expiring = True
self.device_token.expires = now() - timedelta(hours=1)
@@ -1,5 +1,6 @@
from hashlib import sha256
from json import loads
from unittest.mock import PropertyMock, patch

from django.urls import reverse
from jwt import encode

@@ -232,3 +233,43 @@ class TestEndpointStage(FlowTestCase):
plan = plan()
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
self.assertEqual(plan.context[PLAN_CONTEXT_DEVICE], self.device)

def test_endpoint_stage_connector_no_stage_optional(self):
flow = create_test_flow()
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.OPTIONAL)
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)

with patch(
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
PropertyMock(return_value=None),
):
with self.assertFlowFinishes() as plan:
res = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
plan = plan()
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)

def test_endpoint_stage_connector_no_stage_required(self):
flow = create_test_flow()
stage = EndpointStage.objects.create(connector=self.connector, mode=StageMode.REQUIRED)
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)

with patch(
"authentik.endpoints.connectors.agent.models.AgentConnector.stage",
PropertyMock(return_value=None),
):
with self.assertFlowFinishes() as plan:
res = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertStageResponse(
res,
component="ak-stage-access-denied",
error_message="Invalid stage configuration",
)
plan = plan()
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
self.assertNotIn(PLAN_CONTEXT_DEVICE, plan.context)
@@ -8,13 +8,15 @@ from authentik.lib.sentry import SentryIgnoredException
MERGED_VENDOR = "goauthentik.io/@merged"

class EnrollmentMethods(models.TextChoices):
class Capabilities(models.TextChoices):
# Automatically enrolled through user action
AUTOMATIC_USER = "automatic_user"
ENROLL_AUTOMATIC_USER = "enroll_automatic_user"
# Automatically enrolled through connector integration
AUTOMATIC_API = "automatic_api"
ENROLL_AUTOMATIC_API = "enroll_automatic_api"
# Manually enrolled with user interaction (user scanning a QR code for example)
MANUAL_USER = "manual_user"
ENROLL_MANUAL_USER = "enroll_manual_user"
# Supported for use with Endpoints stage
STAGE_ENDPOINTS = "stage_endpoints"

class ConnectorSyncException(SentryIgnoredException):

@@ -34,7 +36,7 @@ class BaseController[T: "Connector"]:
def vendor_identifier() -> str:
raise NotImplementedError

def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
def capabilities(self) -> list[Capabilities]:
return []

def stage_view_enrollment(self) -> StageView | None:

@@ -1,4 +1,4 @@
from authentik.endpoints.models import EndpointStage
from authentik.endpoints.models import Connector, EndpointStage, StageMode
from authentik.flows.stage import StageView

PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"

@@ -6,15 +6,27 @@ PLAN_CONTEXT_ENDPOINT_CONNECTOR = "endpoint_connector"

class EndpointStageView(StageView):

def _get_inner(self):
def _get_inner(self) -> StageView | None:
stage: EndpointStage = self.executor.current_stage
inner_stage: type[StageView] | None = stage.connector.stage
connector: Connector = stage.connector
if not connector.enabled:
return None
inner_stage: type[StageView] | None = connector.stage
if not inner_stage:
return self.executor.stage_ok()
return None
return inner_stage(self.executor, request=self.request)

def dispatch(self, request, *args, **kwargs):
return self._get_inner().dispatch(request, *args, **kwargs)
inner = self._get_inner()
if inner is None:
stage: EndpointStage = self.executor.current_stage
if stage.mode == StageMode.OPTIONAL:
return self.executor.stage_ok()
else:
return self.executor.stage_invalid("Invalid stage configuration")
return inner.dispatch(request, *args, **kwargs)

def cleanup(self):
return self._get_inner().cleanup()
inner = self._get_inner()
if inner is not None:
return inner.cleanup()
@@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from structlog.stdlib import get_logger

from authentik.endpoints.controller import EnrollmentMethods
from authentik.endpoints.controller import Capabilities
from authentik.endpoints.models import Connector

LOGGER = get_logger()
@@ -17,11 +17,11 @@ def endpoints_sync(connector_pk: Any):
connector: Connector | None = (
Connector.objects.filter(pk=connector_pk).select_subclasses().first()
)
if not connector:
if not connector or not connector.enabled:
return
controller = connector.controller
ctrl = controller(connector)
if EnrollmentMethods.AUTOMATIC_API not in ctrl.supported_enrollment_methods():
if Capabilities.AUTOMATIC_API not in ctrl.capabilities():
return
LOGGER.info("Syncing connector", connector=connector.name)
ctrl.sync_endpoints()
41 authentik/endpoints/tests/test_api.py (Normal file)

@@ -0,0 +1,41 @@
from django.urls import reverse
from rest_framework.test import APITestCase

from authentik.core.tests.utils import create_test_admin_user
from authentik.endpoints.connectors.agent.models import AgentConnector
from authentik.endpoints.models import StageMode
from authentik.enterprise.endpoints.connectors.fleet.models import FleetConnector
from authentik.lib.generators import generate_id

class TestAPI(APITestCase):
def setUp(self):
self.user = create_test_admin_user()
self.client.force_login(self.user)

def test_endpoint_stage_agent(self):
connector = AgentConnector.objects.create(name=generate_id())
res = self.client.post(
reverse("authentik_api:stages-endpoint-list"),
data={
"name": generate_id(),
"connector": str(connector.pk),
"mode": StageMode.REQUIRED,
},
)
self.assertEqual(res.status_code, 201)

def test_endpoint_stage_fleet(self):
connector = FleetConnector.objects.create(name=generate_id())
res = self.client.post(
reverse("authentik_api:stages-endpoint-list"),
data={
"name": generate_id(),
"connector": str(connector.pk),
"mode": StageMode.REQUIRED,
},
)
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content, {"connector": ["Selected connector is not compatible with this stage."]}
)
@@ -3,6 +3,7 @@ from hmac import compare_digest

from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, QueryDict

from authentik.common.oauth.constants import QS_LOGIN_HINT
from authentik.endpoints.connectors.agent.auth import (
agent_auth_issue_token,
check_device_policies,

@@ -14,7 +15,7 @@ from authentik.enterprise.policy import EnterprisePolicyAccessView
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_DEVICE, FlowPlanner
from authentik.flows.stage import StageView
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme

QS_AGENT_IA_TOKEN = "ak-auth-ia-token" # nosec

@@ -64,14 +65,14 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):

planner = FlowPlanner(self.connector.authorization_flow)
planner.allow_empty_flows = True
context = {
PLAN_CONTEXT_DEVICE: self.device,
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
}
if QS_LOGIN_HINT in request.GET:
context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = request.GET[QS_LOGIN_HINT]
try:
plan = planner.plan(
self.request,
{
PLAN_CONTEXT_DEVICE: self.device,
PLAN_CONTEXT_DEVICE_AUTH_TOKEN: self.auth_token,
},
)
plan = planner.plan(self.request, context)
except FlowNonApplicableException:
return self.handle_no_permission_authenticated()
plan.append_stage(in_memory_stage(AgentAuthFulfillmentStage))

@@ -84,7 +85,6 @@ class AgentInteractiveAuth(EnterprisePolicyAccessView):

class AgentAuthFulfillmentStage(StageView):

def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
device: Device = self.executor.plan.context.pop(PLAN_CONTEXT_DEVICE)
auth_token: DeviceAuthenticationToken = self.executor.plan.context.pop(
@@ -6,7 +6,7 @@ from requests import RequestException
from rest_framework.exceptions import ValidationError

from authentik.core.models import User
from authentik.endpoints.controller import BaseController, ConnectorSyncException, EnrollmentMethods
from authentik.endpoints.controller import BaseController, Capabilities, ConnectorSyncException
from authentik.endpoints.facts import (
DeviceFacts,
OSFamily,

@@ -43,8 +43,8 @@ class FleetController(BaseController[DBC]):
def vendor_identifier() -> str:
return "fleetdm.com"

def supported_enrollment_methods(self) -> list[EnrollmentMethods]:
return [EnrollmentMethods.AUTOMATIC_API]
def capabilities(self) -> list[Capabilities]:
return [Capabilities.ENROLL_AUTOMATIC_API]

def _url(self, path: str) -> str:
return f"{self.connector.url}{path}"
@@ -15,6 +15,7 @@ from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from jwt.algorithms import ECAlgorithm
from rest_framework.exceptions import ValidationError
from rest_framework.fields import (
ChoiceField,

@@ -109,13 +110,20 @@ class LicenseKey:
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license") from None
_validate_curve_original = ECAlgorithm._validate_curve
try:
# authentik's license are generated with `algorithm="ES512"` and signed with
# a key of curve `secp384r1`. Starting with version 2.11.0, pyjwt enforces the spec, see
# https://github.com/jpadilla/pyjwt/commit/5b8622773358e56d3d3c0a9acf404809ff34433a
# authentik will change its license generation to `algorithm="ES384"` in 2026.
# TODO: remove this when the last incompatible license runs out.
ECAlgorithm._validate_curve = lambda *_: True
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
algorithms=["ES384", "ES512"],
audience=get_license_aud(),
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
),

@@ -125,6 +133,8 @@ class LicenseKey:
if unverified["aud"] != get_license_aud():
raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None
finally:
ECAlgorithm._validate_curve = _validate_curve_original
return body

@staticmethod
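The hunk above widens the accepted JWT algorithms to include ES384 alongside the legacy ES512 licenses. As a small, illustrative pyjwt round trip under that assumption (the key, claims, and audience below are made up and are not authentik's actual license format):

```python
import jwt  # pyjwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Generate an illustrative P-384 key pair and serialize it to PEM.
private_key = ec.generate_private_key(ec.SECP384R1())
private_pem = private_key.private_bytes(
    serialization.Encoding.PEM,
    serialization.PrivateFormat.PKCS8,
    serialization.NoEncryption(),
)
public_pem = private_key.public_key().public_bytes(
    serialization.Encoding.PEM,
    serialization.PublicFormat.SubjectPublicKeyInfo,
)

# Issue with ES384, then verify while accepting either signing algorithm.
token = jwt.encode({"aud": "example-install-id"}, private_pem, algorithm="ES384")
claims = jwt.decode(
    token,
    public_pem,
    algorithms=["ES384", "ES512"],
    audience="example-install-id",
)
print(claims)
```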
@@ -1,11 +1,11 @@
from datetime import date
from datetime import datetime

from django.db.models import BooleanField as ModelBooleanField
from django.db.models import Case, Q, Value, When
from django_filters.rest_framework import BooleanFilter, FilterSet
from drf_spectacular.utils import extend_schema, extend_schema_field
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.fields import DateField, IntegerField, SerializerMethodField
from rest_framework.fields import IntegerField, SerializerMethodField
from rest_framework.mixins import CreateModelMixin
from rest_framework.request import Request
from rest_framework.response import Response

@@ -21,6 +21,7 @@ from authentik.enterprise.lifecycle.utils import (
ReviewerUserSerializer,
admin_link_for_model,
parse_content_type,
start_of_day,
)
from authentik.lib.utils.time import timedelta_from_string

@@ -67,13 +68,13 @@ class LifecycleIterationSerializer(EnterpriseRequiredMixin, ModelSerializer):
def get_object_admin_url(self, iteration: LifecycleIteration) -> str:
return admin_link_for_model(iteration.object)

@extend_schema_field(DateField())
def get_grace_period_end(self, iteration: LifecycleIteration) -> date:
return iteration.opened_on + timedelta_from_string(iteration.rule.grace_period)
def get_grace_period_end(self, iteration: LifecycleIteration) -> datetime:
return start_of_day(
iteration.opened_on + timedelta_from_string(iteration.rule.grace_period)
)

@extend_schema_field(DateField())
def get_next_review_date(self, iteration: LifecycleIteration):
return iteration.opened_on + timedelta_from_string(iteration.rule.interval)
def get_next_review_date(self, iteration: LifecycleIteration) -> datetime:
return start_of_day(iteration.opened_on + timedelta_from_string(iteration.rule.interval))

def get_user_can_review(self, iteration: LifecycleIteration) -> bool:
return iteration.user_can_review(self.context["request"].user)

@@ -102,7 +103,7 @@ class IterationViewSet(EnterpriseRequiredMixin, CreateModelMixin, GenericViewSet
default=Value(False),
output_field=ModelBooleanField(),
)
)
).distinct()

@action(
detail=False,
@@ -0,0 +1,18 @@
# Generated by Django 5.2.11 on 2026-02-13 09:33

from django.db import migrations, models

class Migration(migrations.Migration):

dependencies = [
("authentik_lifecycle", "0001_initial"),
]

operations = [
migrations.AlterField(
model_name="lifecycleiteration",
name="opened_on",
field=models.DateTimeField(auto_now_add=True),
),
]
@@ -1,3 +1,4 @@
from datetime import timedelta
from uuid import uuid4

from django.contrib.contenttypes.fields import GenericForeignKey

@@ -13,7 +14,7 @@ from rest_framework.serializers import BaseSerializer

from authentik.blueprints.models import ManagedModel
from authentik.core.models import Group, User
from authentik.enterprise.lifecycle.utils import link_for_model
from authentik.enterprise.lifecycle.utils import link_for_model, start_of_day
from authentik.events.models import Event, EventAction, NotificationSeverity, NotificationTransport
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator

@@ -98,7 +99,9 @@ class LifecycleRule(SerializerModel):

def _get_newly_overdue_iterations(self) -> QuerySet[LifecycleIteration]:
return self.lifecycleiteration_set.filter(
opened_on__lte=timezone.now() - timedelta_from_string(self.grace_period),
opened_on__lt=start_of_day(
timezone.now() + timedelta(days=1) - timedelta_from_string(self.grace_period)
),
state=ReviewState.PENDING,
)

@@ -106,7 +109,9 @@ class LifecycleRule(SerializerModel):
recent_iteration_ids = LifecycleIteration.objects.filter(
content_type=self.content_type,
object_id__isnull=False,
opened_on__gte=timezone.now() - timedelta_from_string(self.interval),
opened_on__gte=start_of_day(
timezone.now() + timedelta(days=1) - timedelta_from_string(self.interval)
),
).values_list(Cast("object_id", output_field=self._get_pk_field()), flat=True)

return self.get_objects().exclude(pk__in=recent_iteration_ids)
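The new filter works because `start_of_day(now + 1 day - grace_period)` collapses to the same midnight cutoff no matter when during the day the task runs. A short worked example of that arithmetic, with `start_of_day` re-declared locally so it runs standalone and purely illustrative dates:

```python
from datetime import datetime, timedelta, timezone

def start_of_day(value: datetime) -> datetime:
    return value.replace(hour=0, minute=0, second=0, microsecond=0)

grace_period = timedelta(days=1)
opened_on = datetime(2025, 6, 14, 23, 59, 0, tzinfo=timezone.utc)

# Whether the daily task fires just after midnight or just before the next
# one, the cutoff normalizes to 2025-06-15 00:00, so the comparison is stable.
for now in (
    datetime(2025, 6, 15, 0, 0, 1, tzinfo=timezone.utc),
    datetime(2025, 6, 15, 23, 59, 59, tzinfo=timezone.utc),
):
    cutoff = start_of_day(now + timedelta(days=1) - grace_period)
    print(now.isoformat(), opened_on < cutoff)  # True both times: overdue
```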
@@ -186,7 +191,7 @@ class LifecycleIteration(SerializerModel, ManagedModel):
rule = models.ForeignKey(LifecycleRule, null=True, on_delete=models.SET_NULL)

state = models.CharField(max_length=10, choices=ReviewState, default=ReviewState.PENDING)
opened_on = models.DateField(auto_now_add=True)
opened_on = models.DateTimeField(auto_now_add=True)

class Meta:
indexes = [models.Index(fields=["content_type", "opened_on"])]

@@ -1,3 +1,4 @@
import datetime as dt
from datetime import timedelta
from unittest.mock import patch

@@ -319,7 +320,7 @@ class TestLifecycleModels(TestCase):
content_type=content_type, object_id=str(app_one.pk), rule=rule_overdue
)
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=(timezone.now().date() - timedelta(days=20))
opened_on=(timezone.now() - timedelta(days=20))
)

# Apply again to trigger overdue logic

@@ -383,7 +384,7 @@ class TestLifecycleModels(TestCase):
content_type=content_type, object_id=str(app_overdue.pk), rule=rule_overdue
)
LifecycleIteration.objects.filter(pk=overdue_iteration.pk).update(
opened_on=(timezone.now().date() - timedelta(days=20))
opened_on=(timezone.now() - timedelta(days=20))
)

# Apply overdue rule to mark iteration as overdue

@@ -667,3 +668,178 @@ class TestLifecycleModels(TestCase):
reviewers = list(rule.get_reviewers())
self.assertIn(explicit_reviewer, reviewers)
self.assertIn(group_member, reviewers)

class TestLifecycleDateBoundaries(TestCase):
"""Verify that start_of_day normalization ensures correct overdue/due
detection regardless of exact task execution time within a day.

The daily task may run at any point during the day. The start_of_day
normalization in _get_newly_overdue_iterations and _get_newly_due_objects
ensures that the boundary is always at midnight, so millisecond variations
in task execution time do not affect results."""

def _create_rule_and_iteration(self, grace_period="days=1", interval="days=365"):
app = Application.objects.create(name=generate_id(), slug=generate_id())
content_type = ContentType.objects.get_for_model(Application)
rule = LifecycleRule.objects.create(
name=generate_id(),
content_type=content_type,
object_id=str(app.pk),
interval=interval,
grace_period=grace_period,
)
iteration = LifecycleIteration.objects.get(
content_type=content_type, object_id=str(app.pk), rule=rule
)
return app, rule, iteration

def test_overdue_iteration_opened_yesterday(self):
"""grace_period=1 day: iteration opened yesterday at any time is overdue today."""
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
for opened_on in [
dt.datetime(2025, 6, 14, 0, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 14, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 14, 23, 59, 59, 999999, tzinfo=dt.UTC),
]:
with self.subTest(opened_on=opened_on):
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=opened_on, state=ReviewState.PENDING
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))

def test_not_overdue_iteration_opened_today(self):
"""grace_period=1 day: iteration opened today at any time is NOT overdue."""
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
for opened_on in [
dt.datetime(2025, 6, 15, 0, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 23, 59, 59, 999999, tzinfo=dt.UTC),
]:
with self.subTest(opened_on=opened_on):
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=opened_on, state=ReviewState.PENDING
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertNotIn(iteration, list(rule._get_newly_overdue_iterations()))

def test_overdue_independent_of_task_execution_time(self):
"""Overdue detection gives the same result whether the task runs at 00:00:01 or 23:59:59."""
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=1")
opened_on = dt.datetime(2025, 6, 14, 18, 0, 0, tzinfo=dt.UTC)
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=opened_on, state=ReviewState.PENDING
)
for task_time in [
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
]:
with self.subTest(task_time=task_time):
with patch("django.utils.timezone.now", return_value=task_time):
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))

def test_overdue_boundary_multi_day_grace_period(self):
"""grace_period=30 days: overdue after 30 full days, not after 29."""
_, rule, iteration = self._create_rule_and_iteration(grace_period="days=30")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)

# Opened 30 days ago (May 16), should go overdue
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=dt.datetime(2025, 5, 16, 12, 0, 0, tzinfo=dt.UTC),
state=ReviewState.PENDING,
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertIn(iteration, list(rule._get_newly_overdue_iterations()))

# Opened 29 days ago (May 17), should NOT go overdue
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=dt.datetime(2025, 5, 17, 12, 0, 0, tzinfo=dt.UTC),
state=ReviewState.PENDING,
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertNotIn(iteration, list(rule._get_newly_overdue_iterations()))

def test_due_object_iteration_opened_yesterday(self):
"""interval=1 day: object with iteration opened yesterday is due for a new review."""
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
for opened_on in [
dt.datetime(2025, 6, 14, 0, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 14, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 14, 23, 59, 59, 999999, tzinfo=dt.UTC),
]:
with self.subTest(opened_on=opened_on):
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertIn(app, list(rule._get_newly_due_objects()))

def test_not_due_object_iteration_opened_today(self):
"""interval=1 day: object with iteration opened today is NOT due."""
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)
for opened_on in [
dt.datetime(2025, 6, 15, 0, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 23, 59, 59, 999999, tzinfo=dt.UTC),
]:
with self.subTest(opened_on=opened_on):
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertNotIn(app, list(rule._get_newly_due_objects()))

def test_due_independent_of_task_execution_time(self):
"""Due detection gives the same result whether the task runs at 00:00:01 or 23:59:59."""
app, rule, iteration = self._create_rule_and_iteration(interval="days=1")
opened_on = dt.datetime(2025, 6, 14, 18, 0, 0, tzinfo=dt.UTC)
LifecycleIteration.objects.filter(pk=iteration.pk).update(opened_on=opened_on)
for task_time in [
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
]:
with self.subTest(task_time=task_time):
with patch("django.utils.timezone.now", return_value=task_time):
self.assertIn(app, list(rule._get_newly_due_objects()))

def test_due_boundary_multi_day_interval(self):
"""interval=30 days: due after 30 full days, not after 29."""
app, rule, iteration = self._create_rule_and_iteration(interval="days=30")
fixed_now = dt.datetime(2025, 6, 15, 14, 30, 0, tzinfo=dt.UTC)

# Previous review opened 30 days ago (May 16), review is due for the object
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=dt.datetime(2025, 5, 16, 12, 0, 0, tzinfo=dt.UTC)
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertIn(app, list(rule._get_newly_due_objects()))

# Previous review opened 29 days ago (May 17), new review is NOT due
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=dt.datetime(2025, 5, 17, 12, 0, 0, tzinfo=dt.UTC)
)
with patch("django.utils.timezone.now", return_value=fixed_now):
self.assertNotIn(app, list(rule._get_newly_due_objects()))

def test_apply_overdue_at_boundary(self):
"""apply() marks iteration overdue when grace period just expired,
regardless of what time the daily task runs."""
_, rule, iteration = self._create_rule_and_iteration(
grace_period="days=1", interval="days=365"
)
opened_on = dt.datetime(2025, 6, 14, 20, 0, 0, tzinfo=dt.UTC)
for task_time in [
dt.datetime(2025, 6, 15, 0, 0, 1, tzinfo=dt.UTC),
dt.datetime(2025, 6, 15, 23, 59, 59, tzinfo=dt.UTC),
]:
with self.subTest(task_time=task_time):
LifecycleIteration.objects.filter(pk=iteration.pk).update(
opened_on=opened_on, state=ReviewState.PENDING
)
with patch("django.utils.timezone.now", return_value=task_time):
rule.apply()
iteration.refresh_from_db()
self.assertEqual(iteration.state, ReviewState.OVERDUE)
@@ -1,3 +1,4 @@
from datetime import datetime
from urllib import parse

from django.contrib.contenttypes.models import ContentType
@@ -39,6 +40,10 @@ def link_for_model(model: Model) -> str:
    return f"{reverse("authentik_core:if-admin")}#{admin_link_for_model(model)}"


def start_of_day(dt: datetime) -> datetime:
    return dt.replace(hour=0, minute=0, second=0, microsecond=0)


class ContentTypeField(ChoiceField):
    def __init__(self, **kwargs):
        super().__init__(choices=model_choices(), **kwargs)
@@ -331,7 +331,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test group discovery"""
|
||||
@@ -372,7 +372,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response to trigger update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -309,7 +309,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
|
||||
def test_sync_discover_multiple(self):
|
||||
"""Test user discovery, running multiple times"""
|
||||
@@ -352,7 +352,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
).exists()
|
||||
)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 5)
|
||||
self.assertEqual(len(http.requests()), 7)
|
||||
# Change response, which will trigger a discovery update
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
|
||||
|
||||
@@ -78,7 +78,8 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
microsoft_user = self.to_schema(user, None)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
if microsoft_user.user_principal_name:
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
with transaction.atomic():
|
||||
try:
|
||||
response = self._request(self.client.users.post(microsoft_user))
|
||||
@@ -118,7 +119,8 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
def update(self, user: User, connection: MicrosoftEntraProviderUser):
|
||||
"""Update existing user"""
|
||||
microsoft_user = self.to_schema(user, connection)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
if microsoft_user.user_principal_name:
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
response = self._request(
|
||||
self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user)
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.urls import reverse
|
||||
from rest_framework.fields import CharField, SerializerMethodField, URLField
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.ws_federation.models import WSFederationProvider
|
||||
from authentik.enterprise.providers.ws_federation.processors.metadata import MetadataProcessor
|
||||
@@ -18,6 +19,29 @@ class WSFederationProviderSerializer(EnterpriseRequiredMixin, SAMLProviderSerial
|
||||
wtrealm = CharField(source="audience")
|
||||
url_wsfed = SerializerMethodField()
|
||||
|
||||
def get_url_download_metadata(self, instance: WSFederationProvider) -> str:
|
||||
"""Get metadata download URL"""
|
||||
if "request" not in self._context:
|
||||
return ""
|
||||
request: HttpRequest = self._context["request"]._request
|
||||
try:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_ws_federation:metadata-download",
|
||||
kwargs={"application_slug": instance.application.slug},
|
||||
)
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_api:wsfederationprovider-metadata",
|
||||
kwargs={
|
||||
"pk": instance.pk,
|
||||
},
|
||||
)
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
def get_url_wsfed(self, instance: WSFederationProvider) -> str:
|
||||
"""Get WS-Fed url"""
|
||||
if "request" not in self._context:
|
||||
|
||||
@@ -81,6 +81,8 @@ class SignInProcessor:
|
||||
self.sign_in_request = sign_in_request
|
||||
self.saml_processor = AssertionProcessor(self.provider, self.request, AuthNRequest())
|
||||
self.saml_processor.provider.audience = self.sign_in_request.wtrealm
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor.provider.sign_assertion = True
|
||||
|
||||
def create_response_token(self):
|
||||
root = Element(f"{{{NS_WS_FED_TRUST}}}RequestSecurityTokenResponse", nsmap=NS_MAP)
|
||||
@@ -148,7 +150,8 @@ class SignInProcessor:
|
||||
def response(self) -> dict[str, str]:
|
||||
root = self.create_response_token()
|
||||
assertion = root.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self.saml_processor._sign(assertion)
|
||||
if self.provider.signing_kp:
|
||||
self.saml_processor._sign(assertion)
|
||||
str_token = etree.tostring(root).decode("utf-8") # nosec
|
||||
return delete_none_values(
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from django.urls import path
|
||||
|
||||
from authentik.enterprise.providers.ws_federation.api.providers import WSFederationProviderViewSet
|
||||
from authentik.enterprise.providers.ws_federation.views import WSFedEntryView
|
||||
from authentik.enterprise.providers.ws_federation.views import MetadataDownload, WSFedEntryView
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
@@ -11,6 +11,12 @@ urlpatterns = [
|
||||
WSFedEntryView.as_view(),
|
||||
name="wsfed",
|
||||
),
|
||||
# Metadata
|
||||
path(
|
||||
"<slug:application_slug>/metadata/",
|
||||
MetadataDownload.as_view(),
|
||||
name="metadata-download",
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.shortcuts import get_object_or_404, redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from django.views import View
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
@@ -160,3 +162,24 @@ class WSFedFlowFinalView(ChallengeStageView):
|
||||
"attrs": response,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MetadataDownload(View):
|
||||
"""Redirect to metadata download"""
|
||||
|
||||
def dispatch(self, request: HttpRequest, application_slug: str) -> HttpResponse:
|
||||
app = Application.objects.filter(slug=application_slug).with_provider().first()
|
||||
if not app:
|
||||
raise Http404
|
||||
provider = app.get_provider()
|
||||
if not provider:
|
||||
raise Http404
|
||||
return redirect(
|
||||
reverse(
|
||||
"authentik_api:wsfederationprovider-metadata",
|
||||
kwargs={
|
||||
"pk": provider.pk,
|
||||
},
|
||||
)
|
||||
+ "?download"
|
||||
)
|
||||
|
||||
@@ -93,11 +93,13 @@ def on_login_failed(
|
||||
credentials: dict[str, str],
|
||||
request: HttpRequest,
|
||||
stage: Stage | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Failed Login, authentik custom event"""
|
||||
user = User.objects.filter(username=credentials.get("username")).first()
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **kwargs).from_http(
|
||||
context = context or {}
|
||||
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **context).from_http(
|
||||
request, user
|
||||
)
|
||||
|
||||
|
||||
@@ -207,3 +207,9 @@ class TestEvents(TestCase):
|
||||
"username": user.username,
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_string(self):
|
||||
"""Test creating an event with invalid unicode string data"""
|
||||
event = Event.new("unittest", foo="foo bar \u0000 baz")
|
||||
event.save()
|
||||
self.assertEqual(event.context["foo"], "foo bar baz")
|
||||
|
||||
@@ -36,6 +36,10 @@ ALLOWED_SPECIAL_KEYS = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def cleanse_str(raw: Any) -> str:
|
||||
return str(raw).replace("\u0000", "")
|
||||
|
||||
|
||||
def cleanse_item(key: str, value: Any) -> Any:
|
||||
"""Cleanse a single item"""
|
||||
if isinstance(value, dict):
|
||||
@@ -66,7 +70,7 @@ def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
def model_to_dict(model: Model) -> dict[str, Any]:
|
||||
"""Convert model to dict"""
|
||||
name = str(model)
|
||||
name = cleanse_str(model)
|
||||
if hasattr(model, "name"):
|
||||
name = model.name
|
||||
return {
|
||||
@@ -133,11 +137,11 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
if isinstance(value, ASN):
|
||||
return ASN_CONTEXT_PROCESSOR.asn_to_dict(value)
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Exception):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, YAMLTag):
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, type):
|
||||
@@ -161,7 +165,7 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
raise ValueError("JSON can't represent timezone-aware times.")
|
||||
return value.isoformat()
|
||||
if isinstance(value, timedelta):
|
||||
return str(value.total_seconds())
|
||||
return cleanse_str(value.total_seconds())
|
||||
if callable(value):
|
||||
return {
|
||||
"type": "callable",
|
||||
@@ -174,8 +178,8 @@ def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912
|
||||
try:
|
||||
return DjangoJSONEncoder().default(value)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
return str(value)
|
||||
return cleanse_str(value)
|
||||
return cleanse_str(value)
|
||||
|
||||
|
||||
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||
|
||||
@@ -29,6 +29,12 @@ class RefreshOtherFlowsAfterAuthentication(Flag[bool], key="flows_refresh_others
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class ContinuousLogin(Flag[bool], key="flows_continuous_login"):
|
||||
|
||||
default = False
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class AuthentikFlowsConfig(ManagedAppConfig):
|
||||
"""authentik flows app config"""
|
||||
|
||||
|
||||
@@ -9,7 +9,15 @@
|
||||
{{ block.super }}
|
||||
<link rel="prefetch" href="{{ flow_background_url }}" />
|
||||
{% if flow.compatibility_mode and not inspector %}
|
||||
<script data-id="shady-dom">ShadyDOM = { force: true };</script>
|
||||
{% comment %}
|
||||
@see {@link web/types/webcomponents.d.ts} for type definitions.
|
||||
{% endcomment %}
|
||||
<script data-id="shady-dom">
|
||||
"use strict";
|
||||
|
||||
window.ShadyDOM = window.ShadyDOM || {}
|
||||
window.ShadyDOM.force = true
|
||||
</script>
|
||||
{% endif %}
|
||||
{% include "base/header_js.html" %}
|
||||
<script data-id="flow-config">
|
||||
@@ -45,16 +53,11 @@
|
||||
slug="{{ flow.slug }}"
|
||||
class="pf-c-login"
|
||||
data-layout="{{ flow.layout|default:'stacked' }}"
|
||||
loading
|
||||
>
|
||||
{% include "base/placeholder.html" %}
|
||||
|
||||
<ak-brand-links
|
||||
slot="footer"
|
||||
exportparts="list:brand-links-list, list-item:brand-links-list-item"
|
||||
role="contentinfo"
|
||||
aria-label="{% trans 'Site footer' %}"
|
||||
class="pf-c-login__footer {% if flow.layout == 'stacked' %}pf-m-dark{% endif %}"
|
||||
></ak-brand-links>
|
||||
<ak-brand-links name="flow-links" slot="footer"></ak-brand-links>
|
||||
</ak-flow-executor>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -141,6 +141,10 @@ web:
  # workers: 2
  threads: 4
  path: /
  timeout_http_read_header: 5s
  timeout_http_read: 30s
  timeout_http_write: 60s
  timeout_http_idle: 120s

worker:
  processes: 1
@@ -178,3 +182,5 @@ storage:
  # backend: file # or s3
  # file: {}
  # s3: {}

skip_migrations: false
@@ -42,7 +42,7 @@ ARG_SANITIZE = re.compile(r"[:.-]")
|
||||
|
||||
|
||||
def sanitize_arg(arg_name: str) -> str:
|
||||
return re.sub(ARG_SANITIZE, "_", arg_name)
|
||||
return re.sub(ARG_SANITIZE, "_", slugify(arg_name))
|
||||
|
||||
|
||||
class BaseEvaluator:
|
||||
@@ -311,7 +311,9 @@ class BaseEvaluator:
|
||||
|
||||
def wrap_expression(self, expression: str) -> str:
|
||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
|
||||
handler_signature = ",".join(
|
||||
[x for x in [sanitize_arg(x) for x in self._context.keys()] if x]
|
||||
)
|
||||
full_expression = ""
|
||||
full_expression += f"def handler({handler_signature}):\n"
|
||||
full_expression += indent(expression, " ")
|
||||
|
||||
@@ -103,6 +103,7 @@ class SyncTasks:
            )
            users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
            group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
            self._sync_cleanup(provider, task)
        except TransientSyncException as exc:
            self.logger.warning("transient sync exception", exc=exc)
            task.warning("Sync encountered a transient exception. Retrying", exc=exc)
@@ -111,6 +112,35 @@ class SyncTasks:
            task.error(exc)
            return

    def _sync_cleanup(self, provider: OutgoingSyncProvider, task: Task):
        """Delete remote objects that are no longer in scope"""
        for object_type in (User, Group):
            try:
                client = provider.client_for_model(object_type)
            except TransientSyncException:
                continue
            in_scope_pks = set(provider.get_object_qs(object_type).values_list("pk", flat=True))
            stale = client.connection_type.objects.filter(provider=provider).exclude(
                **{f"{client.connection_type_query}__pk__in": in_scope_pks}
            )
            for connection in stale:
                try:
                    client.delete(connection.scim_id)
                    task.info(
                        f"Deleted out-of-scope {object_type._meta.verbose_name}",
                        scim_id=connection.scim_id,
                    )
                except NotFoundSyncException:
                    pass
                except TransientSyncException as exc:
                    self.logger.warning("transient error during cleanup", exc=exc)
                    self.logger.warning(
                        "Cleanup encountered a transient exception. Retrying", exc=exc
                    )
                    raise Retry() from exc
                except DryRunRejected as exc:
                    self.logger.info("Rejected dry-run cleanup event", exc=exc)

    def sync_objects(
        self,
        object_type: str,
@@ -1,5 +1,6 @@
|
||||
"""Test Evaluator base functions"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
@@ -353,3 +354,18 @@ class TestEvaluator(TestCase):
|
||||
self.assertEqual(message.to, ["to@example.com"])
|
||||
self.assertEqual(message.cc, ["cc1@example.com", "cc2@example.com"])
|
||||
self.assertEqual(message.bcc, ["bcc1@example.com", "bcc2@example.com"])
|
||||
|
||||
def test_expr_arg_escape(self):
|
||||
"""Test escaping of arguments"""
|
||||
eval = BaseEvaluator()
|
||||
eval._context = {
|
||||
'z=getattr(getattr(__import__("os"), "popen")("id > /tmp/test"), "read")()': "bar",
|
||||
"@@": "baz",
|
||||
"{{": "baz",
|
||||
"aa@@": "baz",
|
||||
}
|
||||
res = eval.evaluate("return locals()")
|
||||
self.assertEqual(
|
||||
res, {"zgetattrgetattr__import__os_popenid_tmptest_read": "bar", "aa": "baz"}
|
||||
)
|
||||
self.assertFalse(Path("/tmp/test").exists())
|
||||
|
||||
119 authentik/lib/tests/test_utils_inheritance.py Normal file
@@ -0,0 +1,119 @@
"""Tests for inheritance helpers."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connection, models
|
||||
from django.test import TransactionTestCase
|
||||
from django.test.utils import isolate_apps
|
||||
|
||||
from authentik.lib.utils.inheritance import get_deepest_child
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_inheritance_models():
|
||||
"""Create a temporary multi-table inheritance graph for testing."""
|
||||
with isolate_apps("authentik.lib.tests"):
|
||||
|
||||
class GrandParent(models.Model):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandParent({self.pk})"
|
||||
|
||||
class Parent(GrandParent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Parent({self.pk})"
|
||||
|
||||
class Child(Parent):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Child({self.pk})"
|
||||
|
||||
class GrandChild(Child):
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GrandChild({self.pk})"
|
||||
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.create_model(GrandParent)
|
||||
schema_editor.create_model(Parent)
|
||||
schema_editor.create_model(Child)
|
||||
schema_editor.create_model(GrandChild)
|
||||
|
||||
try:
|
||||
yield GrandParent, Parent, Child, GrandChild
|
||||
finally:
|
||||
with connection.schema_editor() as schema_editor:
|
||||
schema_editor.delete_model(GrandChild)
|
||||
schema_editor.delete_model(Child)
|
||||
schema_editor.delete_model(Parent)
|
||||
schema_editor.delete_model(GrandParent)
|
||||
|
||||
|
||||
class TestInheritanceUtils(TransactionTestCase):
|
||||
"""Tests for helper functions in authentik.lib.utils.inheritance."""
|
||||
|
||||
def test_get_deepest_child_grandparent_to_parent(self):
|
||||
"""GrandParent -> Parent."""
|
||||
with temporary_inheritance_models() as (GrandParent, Parent, _Child, _GrandChild):
|
||||
parent = Parent.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=parent.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Parent)
|
||||
self.assertEqual(resolved.pk, parent.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_child(self):
|
||||
"""GrandParent -> Child."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_grandparent_to_grandchild(self):
|
||||
"""GrandParent -> GrandChild."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.get(pk=grandchild.pk)
|
||||
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
self.assertEqual(resolved.pk, grandchild.pk)
|
||||
|
||||
def test_get_deepest_child_parent_to_child(self):
|
||||
"""Parent -> Child (start from non-root)."""
|
||||
with temporary_inheritance_models() as (_GrandParent, Parent, Child, _GrandChild):
|
||||
child = Child.objects.create()
|
||||
parent = Parent.objects.get(pk=child.pk)
|
||||
|
||||
resolved = get_deepest_child(parent)
|
||||
|
||||
self.assertIsInstance(resolved, Child)
|
||||
self.assertEqual(resolved.pk, child.pk)
|
||||
|
||||
def test_get_deepest_child_no_queries_with_preloaded_relations(self):
|
||||
"""No extra queries when the inheritance chain is fully select_related."""
|
||||
with temporary_inheritance_models() as (GrandParent, _Parent, _Child, GrandChild):
|
||||
grandchild = GrandChild.objects.create()
|
||||
grandparent = GrandParent.objects.select_related("parent__child__grandchild").get(
|
||||
pk=grandchild.pk
|
||||
)
|
||||
|
||||
with self.assertNumQueries(0):
|
||||
resolved = get_deepest_child(grandparent)
|
||||
|
||||
self.assertIsInstance(resolved, GrandChild)
|
||||
41 authentik/lib/utils/inheritance.py Normal file
@@ -0,0 +1,41 @@
from django.db.models import Model, OneToOneField, OneToOneRel


def get_deepest_child(parent: Model) -> Model:
    """
    In multiple table inheritance, given any ancestor object, get the deepest child object.
    See https://docs.djangoproject.com/en/dev/topics/db/models/#multi-table-inheritance

    This function does not query the database if `select_related` has been performed on all
    subclasses of `parent`'s model.
    """

    # Almost verbatim copy from django-model-utils, see
    # https://github.com/jazzband/django-model-utils/blob/5.0.0/model_utils/managers.py#L132
    one_to_one_rels = [
        field for field in parent._meta.get_fields() if isinstance(field, OneToOneRel)
    ]

    submodel_fields = [
        rel
        for rel in one_to_one_rels
        if isinstance(rel.field, OneToOneField)
        and issubclass(rel.field.model, parent._meta.model)
        and parent._meta.model is not rel.field.model
        and rel.parent_link
    ]

    submodel_accessors = [submodel_field.get_accessor_name() for submodel_field in submodel_fields]
    # End Copy

    child = None
    for submodel in submodel_accessors:
        try:
            child = getattr(parent, submodel)
            break
        except AttributeError:
            continue

    if not child:
        return parent
    return get_deepest_child(child)
@@ -185,8 +185,10 @@ class KubernetesObjectReconciler[T]:
|
||||
|
||||
patch = self.get_patch()
|
||||
if patch is not None:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
|
||||
try:
|
||||
current_json = ApiClient().sanitize_for_serialization(current)
|
||||
except AttributeError:
|
||||
current_json = asdict(current)
|
||||
try:
|
||||
if apply_patch(current_json, patch) != current_json:
|
||||
raise NeedsUpdate()
|
||||
|
||||
@@ -163,4 +163,5 @@ def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
if Outpost.objects.exists():
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
|
||||
@@ -7,7 +7,6 @@ from socket import gethostname
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@@ -159,7 +158,7 @@ def outpost_send_update(pk: Any):
|
||||
layer = get_channel_layer()
|
||||
group = build_outpost_group(outpost.pk)
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
layer.group_send_blocking(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@actor(description=_("Checks the local environment and create Service connections."))
|
||||
@@ -210,7 +209,7 @@ def outpost_session_end(session_id: str):
|
||||
for outpost in Outpost.objects.all():
|
||||
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
group,
|
||||
{
|
||||
"type": "event.session.end",
|
||||
|
||||
@@ -132,9 +132,14 @@ class PolicyEngine:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
if self.mode == PolicyEngineMode.MODE_ANY:
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif self.mode == PolicyEngineMode.MODE_ALL:
|
||||
if matched_bindings.get("passing", 0) == matched_bindings["total"]:
|
||||
# All static bindings are passing -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
@@ -185,6 +190,16 @@ class PolicyEngine:
|
||||
# Only call .recv() if no result is saved, otherwise we just deadlock here
|
||||
if not proc_info.result:
|
||||
proc_info.result = proc_info.connection.recv()
|
||||
if proc_info.result and proc_info.result._exec_time:
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=proc_info.binding.order,
|
||||
binding_target_type=proc_info.binding.target_type,
|
||||
binding_target_name=proc_info.binding.target_name,
|
||||
object_type=(
|
||||
class_to_path(self.request.obj.__class__) if self.request.obj else ""
|
||||
),
|
||||
mode="execute_process",
|
||||
).observe(proc_info.result._exec_time)
|
||||
return self
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from multiprocessing import get_context
|
||||
from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from sentry_sdk import start_span
|
||||
@@ -11,8 +12,6 @@ from structlog.stdlib import get_logger
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
||||
@@ -123,18 +122,9 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
|
||||
def profiling_wrapper(self):
|
||||
"""Run with profiling enabled"""
|
||||
with (
|
||||
start_span(
|
||||
op="authentik.policy.process.execute",
|
||||
) as span,
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=self.binding.order,
|
||||
binding_target_type=self.binding.target_type,
|
||||
binding_target_name=self.binding.target_name,
|
||||
object_type=class_to_path(self.request.obj.__class__) if self.request.obj else "",
|
||||
mode="execute_process",
|
||||
).time(),
|
||||
):
|
||||
with start_span(
|
||||
op="authentik.policy.process.execute",
|
||||
) as span:
|
||||
span: Span
|
||||
span.set_data("policy", self.binding.policy)
|
||||
span.set_data("request", self.request)
|
||||
@@ -142,8 +132,14 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
|
||||
def run(self): # pragma: no cover
|
||||
"""Task wrapper to run policy checking"""
|
||||
result = None
|
||||
try:
|
||||
self.connection.send(self.profiling_wrapper())
|
||||
start = perf_counter()
|
||||
result = self.profiling_wrapper()
|
||||
end = perf_counter()
|
||||
result._exec_time = max((end - start), 0)
|
||||
except Exception as exc: # noqa
|
||||
LOGGER.warning("Policy failed to run", exc=exc)
|
||||
self.connection.send(PolicyResult(False, str(exc)))
|
||||
result = PolicyResult(False, str(exc))
|
||||
finally:
|
||||
self.connection.send(result)
|
||||
|
||||
@@ -33,6 +33,9 @@ class TestPolicyEngine(TestCase):
|
||||
self.policy_raises = ExpressionPolicy.objects.create(
|
||||
name=generate_id(), expression="{{ 0/0 }}"
|
||||
)
|
||||
self.group_member = Group.objects.create(name=generate_id())
|
||||
self.user.groups.add(self.group_member)
|
||||
self.group_non_member = Group.objects.create(name=generate_id())
|
||||
|
||||
def test_engine_empty(self):
|
||||
"""Ensure empty policy list passes"""
|
||||
@@ -51,7 +54,7 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ("dummy",))
|
||||
|
||||
def test_engine_mode_all(self):
|
||||
def test_engine_mode_all_dyn(self):
|
||||
"""Ensure all policies passes with AND mode (false and true -> false)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -67,7 +70,7 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_any(self):
|
||||
def test_engine_mode_any_dyn(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
@@ -83,6 +86,26 @@ class TestPolicyEngine(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_engine_mode_all_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_mode_any_static(self):
|
||||
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||
pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_member, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, group=self.group_non_member, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine_negate(self):
|
||||
"""Test negate flag"""
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
|
||||
@@ -77,6 +77,8 @@ class PolicyResult:
|
||||
|
||||
log_messages: list[LogEvent] | None
|
||||
|
||||
_exec_time: int | None
|
||||
|
||||
def __init__(self, passing: bool, *messages: str):
|
||||
self.passing = passing
|
||||
self.messages = messages
|
||||
@@ -84,6 +86,7 @@ class PolicyResult:
|
||||
self.source_binding = None
|
||||
self.source_results = []
|
||||
self.log_messages = []
|
||||
self._exec_time = None
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@@ -68,6 +68,8 @@ class IDToken:
|
||||
at_hash: str | None = None
|
||||
# Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
|
||||
sid: str | None = None
|
||||
# JWT ID, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.7
|
||||
jti: str | None = None
|
||||
|
||||
claims: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -81,6 +83,7 @@ class IDToken:
|
||||
(token.expires if token.expires is not None else default_token_duration()).timestamp()
|
||||
)
|
||||
id_token.iss = provider.get_issuer(request)
|
||||
id_token.jti = generate_id()
|
||||
id_token.aud = provider.client_id
|
||||
id_token.claims = {}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils import translation
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@@ -690,18 +691,21 @@ class TestAuthorize(OAuthTestCase):
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
self.client.logout()
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
try:
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
"ui_locales": "invalid fr",
|
||||
},
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertEqual(parsed["locale"], ["fr"])
|
||||
finally:
|
||||
translation.deactivate()
|
||||
|
||||
@apply_blueprint("default/flow-default-authentication-flow.yaml")
|
||||
def test_ui_locales_invalid(self):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Device backchannel tests"""
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
@@ -26,7 +28,7 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_backchannel_invalid(self):
|
||||
def test_backchannel_invalid_client_id_via_post_body(self):
|
||||
"""Test backchannel"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
@@ -50,7 +52,7 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel(self):
|
||||
def test_backchannel_client_id_via_post_body(self):
|
||||
"""Test backchannel"""
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
@@ -61,3 +63,50 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase):
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
def test_backchannel_invalid_client_id_via_auth_header(self):
|
||||
"""Test backchannel"""
|
||||
creds = b64encode(b"foo:").decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
# test without application
|
||||
self.application.provider = None
|
||||
self.application.save()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
data={
|
||||
"client_id": "test",
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_backchannel_client_id_via_auth_header(self):
|
||||
"""Test backchannel"""
|
||||
creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
def test_backchannel_client_id_via_auth_header_urlencoded(self):
|
||||
"""Test URL-encoded client IDs in Basic auth"""
|
||||
self.provider.client_id = "test/client+id"
|
||||
self.provider.save()
|
||||
creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:device"),
|
||||
HTTP_AUTHORIZATION=f"Basic {creds}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = loads(res.content.decode())
|
||||
self.assertEqual(body["expires_in"], 60)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from json import dumps
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
@@ -28,6 +29,7 @@ from authentik.providers.oauth2.models import (
|
||||
ScopeMapping,
|
||||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.utils import extract_client_auth
|
||||
from authentik.providers.oauth2.views.token import TokenParams
|
||||
|
||||
|
||||
@@ -115,6 +117,20 @@ class TestToken(OAuthTestCase):
|
||||
params = TokenParams.parse(request, provider, provider.client_id, provider.client_secret)
|
||||
self.assertEqual(params.provider, provider)
|
||||
|
||||
def test_extract_client_auth_basic_auth_percent_decodes(self):
|
||||
"""test percent-decoding of client credentials in Basic auth"""
|
||||
header = b64encode(
|
||||
f"{quote('client/id', safe='')}:{quote('secret+/==', safe='')}".encode()
|
||||
).decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client/id", "secret+/=="))
|
||||
|
||||
def test_extract_client_auth_basic_auth_preserves_raw_plus(self):
|
||||
"""test compatibility with clients that still send raw plus characters"""
|
||||
header = b64encode(b"client:secret+plus").decode()
|
||||
request = self.factory.post("/", HTTP_AUTHORIZATION=f"Basic {header}")
|
||||
self.assertEqual(extract_client_auth(request), ("client", "secret+plus"))
|
||||
|
||||
def test_auth_code_view(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from json import loads
|
||||
from urllib.parse import quote
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
@@ -178,6 +179,41 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_basic_auth_urlencoded_client_secret(self):
|
||||
"""test successful with URL-encoded Basic auth credentials"""
|
||||
client_secret = b64encode(f"sa:{self.token.key}".encode()).decode()
|
||||
header = b64encode(
|
||||
f"{quote(self.provider.client_id, safe='')}:{quote(client_secret, safe='')}".encode()
|
||||
).decode()
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
{
|
||||
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["token_type"], TOKEN_TYPE)
|
||||
_, alg = self.provider.jwt_key
|
||||
jwt = decode(
|
||||
body["access_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
jwt = decode(
|
||||
body["id_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
algorithms=[alg],
|
||||
audience=self.provider.client_id,
|
||||
)
|
||||
self.assertEqual(jwt["given_name"], self.user.name)
|
||||
self.assertEqual(jwt["preferred_username"], self.user.username)
|
||||
|
||||
def test_successful_password(self):
|
||||
"""test successful (password grant)"""
|
||||
response = self.client.post(
|
||||
|
||||
@@ -7,7 +7,7 @@ from binascii import Error
from hashlib import sha256
from hmac import compare_digest
from typing import Any
from urllib.parse import urlparse
from urllib.parse import unquote, urlparse

from django.http import HttpRequest, HttpResponse, JsonResponse
from django.http.response import HttpResponseRedirect
@@ -122,6 +122,10 @@ def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
        try:
            user_pass = b64decode(b64_user_pass).decode("utf-8").partition(":")
            client_id, _, client_secret = user_pass
            # RFC 6749 requires client credentials in Basic auth to be form-encoded first.
            # We only percent-decode here so raw `+` characters keep their previous meaning.
            client_id = unquote(client_id)
            client_secret = unquote(client_secret)
        except (ValueError, Error):
            client_id = client_secret = ""  # nosec
        else:
@@ -16,7 +16,7 @@ from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.utils import TokenResponse
|
||||
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -32,7 +32,7 @@ class DeviceView(View):
|
||||
|
||||
def parse_request(self):
|
||||
"""Parse incoming request"""
|
||||
client_id = self.request.POST.get("client_id", None)
|
||||
client_id, _ = extract_client_auth(self.request)
|
||||
if not client_id:
|
||||
raise DeviceCodeError("invalid_client")
|
||||
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
|
||||
|
||||
@@ -27,6 +27,8 @@ class TraefikMiddlewareSpecForwardAuth:
|
||||
|
||||
trustForwardHeader: bool = field(default=True)
|
||||
|
||||
maxResponseBodySize: int = field(default=1024 * 1024 * 4)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TraefikMiddlewareSpec:
|
||||
@@ -140,6 +142,7 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
|
||||
],
|
||||
authResponseHeadersRegex="",
|
||||
trustForwardHeader=True,
|
||||
maxResponseBodySize=1024 * 1024 * 4,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Proxy provider signals"""
|
||||
|
||||
from django.db.models.signals import pre_delete
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.core.models import AuthenticatedSession
|
||||
from authentik.providers.proxy.tasks import proxy_on_logout
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
proxy_on_logout.send(instance.session.session_key)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""proxy provider tasks"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.outposts.consumer import build_outpost_group
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.oauth2.id_token import hash_session_key
|
||||
|
||||
|
||||
@actor(description=_("Terminate session on Proxy outpost."))
|
||||
def proxy_on_logout(session_id: str):
|
||||
layer = get_channel_layer()
|
||||
hashed_session_id = hash_session_key(session_id)
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
group,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": hashed_session_id,
|
||||
},
|
||||
)
|
||||
@@ -232,7 +232,7 @@ class SAMLMetadataSerializer(PassiveSerializer):
|
||||
"""SAML Provider Metadata serializer"""
|
||||
|
||||
metadata = CharField(read_only=True)
|
||||
download_url = CharField(read_only=True, required=False)
|
||||
download_url = CharField(read_only=True, required=False, allow_null=True)
|
||||
|
||||
|
||||
class SAMLProviderImportSerializer(PassiveSerializer):
|
||||
|
||||
@@ -10,6 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderGroup
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
|
||||
|
||||
class SCIMGroupTests(TestCase):
|
||||
@@ -205,3 +206,80 @@ class SCIMGroupTests(TestCase):
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
self.assertNotIn("PUT", [req.method for req in mock.request_history])
|
||||
|
||||
def _create_stale_provider_group(self, scim_id: str) -> Group:
|
||||
"""Create a group that is outside the provider's scope (via group_filters) with an
|
||||
existing SCIMProviderGroup, simulating a previously synced group now out of scope."""
|
||||
self.app.backchannel_providers.remove(self.provider)
|
||||
anchor = Group.objects.create(name=generate_id())
|
||||
stale = Group.objects.create(name=generate_id())
|
||||
self.app.backchannel_providers.add(self.provider)
|
||||
|
||||
self.provider.group_filters.set([anchor])
|
||||
SCIMProviderGroup.objects.create(provider=self.provider, group=stale, scim_id=scim_id)
|
||||
return stale
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) groups are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Groups/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_not_found(self, mock: Mocker):
|
||||
"""Stale group cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderGroup.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_transient_error(self, mock: Mocker):
|
||||
"""Stale group cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.post("https://localhost/Groups", json={"id": generate_id()})
|
||||
mock.delete(f"https://localhost/Groups/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_group_dry_run(self, mock: Mocker):
|
||||
"""Stale group cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_group(scim_id)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""SCIM User tests"""
|
||||
|
||||
from json import loads
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from jsonschema import validate
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.core.models import Application, Group, User, UserTypes
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.sync.outgoing.base import SAFE_METHODS
|
||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderUser
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects, sync_tasks
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@@ -537,3 +539,104 @@ class SCIMUserTests(TestCase):
|
||||
self.assertEqual(mock.call_count, 2)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
|
||||
def _create_stale_provider_user(self, scim_id: str, uid: str) -> User:
|
||||
"""Create a service-account user (excluded from provider scope) with an existing
|
||||
SCIMProviderUser, simulating a previously synced user that is now out of scope."""
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
type=UserTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
SCIMProviderUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
|
||||
return user
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_delete(self, mock: Mocker):
|
||||
"""Stale (out-of-scope) users are deleted during full sync cleanup"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=204)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
self.assertEqual(delete_reqs[0].url, f"https://localhost/Users/{scim_id}")
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_not_found(self, mock: Mocker):
|
||||
"""Stale user cleanup handles 404 from the remote gracefully"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=404)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
self.assertFalse(
|
||||
SCIMProviderUser.objects.filter(provider=self.provider, scim_id=scim_id).exists()
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_transient_error(self, mock: Mocker):
|
||||
"""Stale user cleanup logs and retries on transient HTTP errors"""
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
mock.delete(f"https://localhost/Users/{scim_id}", status_code=429)
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 1)
|
||||
|
||||
@Mocker()
|
||||
def test_sync_cleanup_stale_user_dry_run(self, mock: Mocker):
|
||||
"""Stale user cleanup skips HTTP DELETE in dry_run mode"""
|
||||
self.provider.dry_run = True
|
||||
self.provider.save()
|
||||
scim_id = generate_id()
|
||||
uid = generate_id()
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
self._create_stale_provider_user(scim_id, uid)
|
||||
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
delete_reqs = [r for r in mock.request_history if r.method == "DELETE"]
|
||||
self.assertEqual(len(delete_reqs), 0)
|
||||
|
||||
def test_sync_cleanup_client_for_model_transient(self):
|
||||
"""Cleanup silently skips an object type when client_for_model raises
|
||||
TransientSyncException"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
SCIMProvider,
|
||||
"client_for_model",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
||||
def test_sync_transient_exception(self):
|
||||
"""TransientSyncException in _sync_cleanup is caught by sync() which then
|
||||
schedules a retry"""
|
||||
with Mocker() as mock:
|
||||
mock.get("https://localhost/ServiceProviderConfig", json={})
|
||||
with patch.object(
|
||||
sync_tasks,
|
||||
"_sync_cleanup",
|
||||
side_effect=TransientSyncException("connection failed"),
|
||||
):
|
||||
scim_sync.send(self.provider.pk)
|
||||
|
||||
@@ -60,11 +60,7 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
sources = sources.exclude(pk=self.instance.pk)
|
||||
if sources.exists():
|
||||
raise ValidationError(
|
||||
{
|
||||
"sync_users_password": _(
|
||||
"Only a single LDAP Source with password synchronization is allowed"
|
||||
)
|
||||
}
|
||||
_("Only a single LDAP Source with password synchronization is allowed")
|
||||
)
|
||||
return sync_users_password
|
||||
|
||||
@@ -221,7 +217,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
for sync_class in SYNC_CLASSES:
|
||||
class_name = sync_class.name()
|
||||
all_objects.setdefault(class_name, [])
|
||||
for page in sync_class(source).get_objects(size_limit=10):
|
||||
for page in sync_class(source, Task()).get_objects(size_limit=10):
|
||||
for obj in page:
|
||||
obj: dict
|
||||
obj.pop("raw_attributes", None)
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.utils.translation import gettext_lazy as _
from ldap3 import ALL, NONE, RANDOM, Connection, Server, ServerPool, Tls
from ldap3.core.exceptions import LDAPException, LDAPInsufficientAccessRightsResult, LDAPSchemaError
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger

from authentik.core.models import (
    Group,
@@ -31,6 +32,7 @@ from authentik.tasks.schedules.common import ScheduleSpec
LDAP_TIMEOUT = 15
LDAP_UNIQUENESS = "ldap_uniq"
LDAP_DISTINGUISHED_NAME = "distinguishedName"
LOGGER = get_logger()


def flatten(value: Any) -> Any:
@@ -268,6 +270,7 @@ class LDAPSource(IncomingSyncSource):
        )

        if self.start_tls:
            LOGGER.debug("Connection StartTLS", source=self)
            conn.start_tls(read_server_info=False)
        try:
            successful = conn.bind()
@@ -278,7 +281,9 @@ class LDAPSource(IncomingSyncSource):
            # See https://github.com/goauthentik/authentik/issues/4590
            # See also https://github.com/goauthentik/authentik/issues/3399
            if server_kwargs.get("get_info", ALL) == NONE:
                LOGGER.warning("Failed to connect after schema downgrade", source=self, exc=exc)
                raise exc
            LOGGER.warning("Downgrading connection to no schema info", source=self, exc=exc)
            server_kwargs["get_info"] = NONE
            return self.connection(server, server_kwargs, connection_kwargs)
        finally:
@@ -27,7 +27,7 @@ from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()
SYNC_CLASSES = [
SYNC_CLASSES: list[type[BaseLDAPSynchronizer]] = [
    UserLDAPSynchronizer,
    GroupLDAPSynchronizer,
    MembershipLDAPSynchronizer,
@@ -1,10 +1,19 @@
"""LDAP Source API tests"""

from json import loads
from unittest.mock import MagicMock, patch

from django.db.models import Q
from django.urls import reverse
from rest_framework.exceptions import ErrorDetail
from rest_framework.test import APITestCase

from authentik.lib.generators import generate_key
from authentik.blueprints.tests import apply_blueprint
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id, generate_key
from authentik.sources.ldap.api import LDAPSourceSerializer
from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection

LDAP_PASSWORD = generate_key()

@@ -26,12 +35,13 @@ class LDAPAPITests(APITestCase):
            }
        )
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.errors, {})

    def test_sync_users_password_invalid(self):
        """Ensure only a single source with password sync can be created"""
        LDAPSource.objects.create(
            name="foo",
            slug="foo",
            slug=generate_id(),
            server_uri="ldaps://1.2.3.4",
            bind_cn="",
            bind_password=LDAP_PASSWORD,
@@ -41,15 +51,26 @@ class LDAPAPITests(APITestCase):
        serializer = LDAPSourceSerializer(
            data={
                "name": "foo",
                "slug": " foo",
                "slug": generate_id(),
                "server_uri": "ldaps://1.2.3.4",
                "bind_cn": "",
                "bind_password": LDAP_PASSWORD,
                "base_dn": "dc=foo",
                "sync_users_password": False,
                "sync_users_password": True,
            }
        )
        self.assertFalse(serializer.is_valid())
        self.assertEqual(
            serializer.errors,
            {
                "sync_users_password": [
                    ErrorDetail(
                        string="Only a single LDAP Source with password synchronization is allowed",
                        code="invalid",
                    )
                ]
            },
        )

    def test_sync_users_mapping_empty(self):
        """Check that when sync_users is enabled, property mappings must be set"""
@@ -82,3 +103,32 @@ class LDAPAPITests(APITestCase):
            }
        )
        self.assertFalse(serializer.is_valid())

    @apply_blueprint("system/sources-ldap.yaml")
    def test_sync_debug(self):
        user = create_test_admin_user()
        self.client.force_login(user)

        source: LDAPSource = LDAPSource.objects.create(
            name=generate_id(),
            slug=generate_id(),
            base_dn="dc=goauthentik,dc=io",
            additional_user_dn="ou=users",
            additional_group_dn="ou=groups",
        )
        source.user_property_mappings.set(
            LDAPSourcePropertyMapping.objects.filter(
                Q(managed__startswith="goauthentik.io/sources/ldap/default")
                | Q(managed__startswith="goauthentik.io/sources/ldap/ms")
            )
        )
        connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
        with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
            res = self.client.get(
                reverse("authentik_api:ldapsource-debug", kwargs={"slug": source.slug})
            )
        self.assertEqual(res.status_code, 200)
        body = loads(res.content.decode())
        self.assertIn("users", body)
        self.assertIn("groups", body)
        self.assertIn("membership", body)
@@ -59,7 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):

    def validate(self, attrs: dict) -> dict:
        session = get_http_session()
        source_type = registry.find_type(attrs["provider_type"])
        provider_type_name = attrs.get(
            "provider_type",
            self.instance.provider_type if self.instance else None,
        )
        source_type = registry.find_type(provider_type_name)

        well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url
        inferred_oidc_jwks_url = None
@@ -101,16 +105,15 @@ class OAuthSourceSerializer(SourceSerializer):
            config = jwks_config.json()
            attrs["oidc_jwks"] = config

        provider_type = registry.find_type(attrs.get("provider_type", ""))
        for url in [
            "authorization_url",
            "access_token_url",
            "profile_url",
        ]:
            if getattr(provider_type, url, None) is None:
            if getattr(source_type, url, None) is None:
                if url not in attrs:
                    raise ValidationError(
                        f"{url} is required for provider {provider_type.verbose_name}"
                        f"{url} is required for provider {source_type.verbose_name}"
                    )
        return attrs
31 authentik/sources/oauth/tests/test_api.py Normal file
@@ -0,0 +1,31 @@
from django.urls import reverse
from rest_framework.test import APITestCase

from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource


class TestOAuthSourceAPI(APITestCase):
    def setUp(self):
        self.source = OAuthSource.objects.create(
            name=generate_id(),
            slug=generate_id(),
            provider_type="openidconnect",
            authorization_url="",
            profile_url="",
            consumer_key=generate_id(),
        )
        self.user = create_test_admin_user()

    def test_patch_no_type(self):
        self.client.force_login(self.user)
        res = self.client.patch(
            reverse("authentik_api:oauthsource-detail", kwargs={"slug": self.source.slug}),
            {
                "authorization_url": f"https://{generate_id()}",
                "profile_url": f"https://{generate_id()}",
                "access_token_url": f"https://{generate_id()}",
            },
        )
        self.assertEqual(res.status_code, 200)
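The new test above targets the serializer change from the earlier hunk: on a PATCH, provider_type can be missing from the incoming attrs, so indexing it directly raises KeyError, while the fixed code falls back to the value stored on the existing instance. A small, simplified illustration of that fallback (plain Python, not the real serializer):

def resolve_provider_type(attrs: dict, instance=None):
    # attrs["provider_type"] would raise KeyError on partial updates (PATCH)
    return attrs.get("provider_type", instance.provider_type if instance else None)


class _ExistingSource:
    provider_type = "openidconnect"


# Full payloads carry the field; partial payloads fall back to the stored value
assert resolve_provider_type({"provider_type": "openidconnect"}) == "openidconnect"
assert resolve_provider_type({"authorization_url": "https://example.test"}, _ExistingSource()) == "openidconnect"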
@@ -7,6 +7,7 @@ from django.http import HttpRequest
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from lxml.etree import _Element # nosec
from rest_framework.serializers import Serializer

from authentik.common.saml.constants import (
@@ -217,9 +218,8 @@ class SAMLSource(Source):
    def property_mapping_type(self) -> type[PropertyMapping]:
        return SAMLSourcePropertyMapping

    def get_base_user_properties(self, root: Any, name_id: Any, **kwargs):
    def get_base_user_properties(self, root: _Element, assertion: _Element, name_id: Any, **kwargs):
        attributes = {}
        assertion = root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
        if assertion is None:
            raise ValueError("Assertion element not found")
        attribute_statement = assertion.find(f"{{{NS_SAML_ASSERTION}}}AttributeStatement")
@@ -66,6 +66,8 @@ class ResponseProcessor:

    _http_request: HttpRequest

    _assertion: _Element | None = None

    def __init__(self, source: SAMLSource, request: HttpRequest):
        self._source = source
        self._http_request = request
@@ -122,6 +124,7 @@ class ResponseProcessor:
                index_of,
                decrypted_assertion,
            )
        self._assertion = decrypted_assertion

    def _verify_signature(self, signature_node: _Element):
        """Verify a single signature node"""
@@ -162,6 +165,10 @@ class ResponseProcessor:
            raise InvalidSignature("No Signature exists in the Assertion element.")

        self._verify_signature(signature_nodes[0])
        parent = signature_nodes[0].getparent()
        if parent is None or parent.tag != f"{{{NS_SAML_ASSERTION}}}Assertion":
            raise InvalidSignature("No Signature exists in the Assertion element.")
        self._assertion = parent

    def _verify_request_id(self):
        if self._source.allow_idp_initiated:
@@ -239,14 +246,21 @@ class ResponseProcessor:
            identifier=str(name_id.text),
            user_info={
                "root": self._root,
                "assertion": self.get_assertion(),
                "name_id": name_id,
            },
            policy_context={},
        )

    def get_assertion(self) -> Element | None:
        """Get assertion element, if we have a signed assertion"""
        if self._assertion is not None:
            return self._assertion
        return self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")

    def _get_name_id(self) -> Element:
        """Get NameID Element"""
        assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
        assertion = self.get_assertion()
        if assertion is None:
            raise ValueError("Assertion element not found")
        subject = assertion.find(f"{{{NS_SAML_ASSERTION}}}Subject")
@@ -299,6 +313,7 @@ class ResponseProcessor:
            identifier=str(name_id.text),
            user_info={
                "root": self._root,
                "assertion": self.get_assertion(),
                "name_id": name_id,
            },
            policy_context={
68 authentik/sources/saml/tests/fixtures/response_signed_assertion_dup.xml vendored Normal file
@@ -0,0 +1,68 @@
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="_8e8dc5f69a98cc4c1ff3427e5ce34606fd672f91e6" Version="2.0" IssueInstant="2014-07-17T01:01:48Z" Destination="http://sp.example.com/demo1/index.php?acs" InResponseTo="ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685">
<saml:Issuer>http://idp.example.com/metadata.php</saml:Issuer>
<samlp:Status>
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
</samlp:Status>
<saml:Assertion xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xs="http://www.w3.org/2001/XMLSchema" ID="_other_id_pfxa06693ef-cec7-f4a6-cb7f-ad074445a1a3" Version="2.0" IssueInstant="2014-07-17T01:01:48Z">
<saml:Issuer>http://idp.example.com/metadata.php</saml:Issuer>
<saml:Subject>
<saml:NameID SPNameQualifier="http://sp.example.com/demo1/metadata.php" Format="urn:oasis:names:tc:SAML:2.0:nameid-format:transient">bad</saml:NameID>
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml:SubjectConfirmationData NotOnOrAfter="2024-01-18T06:21:48Z" Recipient="http://sp.example.com/demo1/index.php?acs" InResponseTo="ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"/>
</saml:SubjectConfirmation>
</saml:Subject>
<saml:Conditions NotBefore="2014-07-17T01:01:18Z" NotOnOrAfter="2024-01-18T06:21:48Z">
<saml:AudienceRestriction>
<saml:Audience>http://sp.example.com/demo1/metadata.php</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AuthnStatement AuthnInstant="2014-07-17T01:01:48Z" SessionNotOnOrAfter="2024-07-17T09:01:48Z" SessionIndex="_be9967abd904ddcae3c0eb4189adbe3f71e327cf93">
<saml:AuthnContext>
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:Password</saml:AuthnContextClassRef>
</saml:AuthnContext>
</saml:AuthnStatement>
<saml:AttributeStatement>
<saml:Attribute Name="uid" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic">
<saml:AttributeValue xsi:type="xs:string">bad</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="mail" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic">
<saml:AttributeValue xsi:type="xs:string">bad</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
</saml:Assertion>
<saml:Assertion xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xs="http://www.w3.org/2001/XMLSchema" ID="pfxa06693ef-cec7-f4a6-cb7f-ad074445a1a3" Version="2.0" IssueInstant="2014-07-17T01:01:48Z">
<saml:Issuer>http://idp.example.com/metadata.php</saml:Issuer><ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:SignedInfo><ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>
<ds:SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
<ds:Reference URI="#pfxa06693ef-cec7-f4a6-cb7f-ad074445a1a3"><ds:Transforms><ds:Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature"/><ds:Transform Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/></ds:Transforms><ds:DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/><ds:DigestValue>zNDuGxwP4gVkv/Dzt7kiKo/4gzk=</ds:DigestValue></ds:Reference></ds:SignedInfo><ds:SignatureValue>GLP/vE8uxerB0uDpPslUgLPBL6ePQB619MoQ0I2Y5lAtFE6CB1zh8BnzChRx/bFjNy4byfOe8mFfM0r7WUi1PJOFWyUPoatdLl7wHHBIRTnPpYmu3Tb2Gz0sOP0F8wW7JkBft5gJfVw49nk5si9/3Q3o52jnJZ7dPtqfIOh8uNeopikK0HLF6sU05qCCtjcXfniEnLQFNBFMo9uY5GQqmR5n3nqPz1wYyyfFOAbVmGgBIoO2PfGX2GVLQhltc9qf2JMhks4jgZsZ8iLUIiH1lcLGWZEEs94k8k0P6gSv1uZ7Vbhksd/N9Jq9pCVuEJ/jRPcAdVjzbxqKQAj6ELwr8O6fepTzA+CAdwEolBnx/C6TmSbVZ+IWk6QUGe4x4+IAukC+0hkKENlO0ELOScksvyhpgHbxNA4rp+DhGupCaO/I2RrsQkmvavbqm+wSEspK7scK112SDunjDvqPHsPYgukD33T/97PxTLorg2kKP9HHJwPJKoXXeyOGcA6vwK+RqrAlZ2dLGAgcXo+sJcdCLuvxDNz9VXofBjBZIKVKdmYhm0QJaPYHtuQsAyFavQhdOBOmGHb7QX3YE3Xy4dX4LymtT+Jlb1I4FJSht/9HUIHW1FdhfDak4f7gUgjuMamMddLD0jVgeESupSREzFv/gj2IrctkbgjAO0iuuiBgKMg=</ds:SignatureValue>
<ds:KeyInfo><ds:X509Data><ds:X509Certificate>MIIFUzCCAzugAwIBAgIRAL6tbNcE9Ej9gNlbGKswfFMwDQYJKoZIhvcNAQELBQAwHTEbMBkGA1UEAwwSYXV0aGVudGlrIDIwMjUuNi4zMB4XDTI1MDcxNTE4MDQzNloXDTI2MDcxNjE4MDQzNlowVjEqMCgGA1UEAwwhYXV0aGVudGlrIFNlbGYtc2lnbmVkIENlcnRpZmljYXRlMRIwEAYDVQQKDAlhdXRoZW50aWsxFDASBgNVBAsMC1NlbGYtc2lnbmVkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAjmut/+bBRLlyrbf+WIfg8ZTw9t6VnsiU1n04nPTulpRAz4nBOoOHNRIruSpZyFeFa6x9jwn4Ma5EFUH7HqnRvhoujm8U17OglXWZt0DLCZ6S5xPmdMogFXjJDmg9okIcI/cb9VbR6I8uvm1oiaOWCr36RTiqZ6rmdjQcuUPLr1+V/LxWQI463S+5QA2HZxAGalp45MJAz2sa9iczktKMgyYlfjj1cruFARxxeheu5qIK7aQWfyPj1QlMb9mi4VQaxUwGrAui4Tq614ivRJY2SkZb0Aq/LLSQoQWYHtYyQIasrOXJm0JuPDqhINPBDowyhu8DihC3uzOpmTXLKc5UoIQk+Q1h5iH74A3/kxOJUw13FXzRiDxC/yGthPYLyFHsDiJolscMKSCqlDvEMcpM4mxFeud9sKUb71SZr8sqmJl3qtvZmKpkR4y8pN2c00p10t0htqONmr5kyPxmhz0HCrosiPYB4olNjaydKviNTtPJ7TtnPyeA3iXGzCP1e80XzUoJrDqON5/GcpYgqsP/kGj8Qvqesa4Fez+1+5pAGHN2VzQbkHAgK3s4YRXrGLTs7wg27F9T0RE28Mm0RYBkYpdp4/5PuTTulthB9mkUBSJMgENmQAYkapvonFDsJkTi39qnsddbZusOLT4z3hsA38eFEwRqnbNZVUGPIp/O1SsCAwEAAaNVMFMwUQYDVR0RAQH/BEcwRYJDRUZ4QXVLRzV6SlVUTWpWNTJoMkRJMUQ5MXdLblZKaXFwNmpwRTRTTy5zZWxmLXNpZ25lZC5nb2F1dGhlbnRpay5pbzANBgkqhkiG9w0BAQsFAAOCAgEAYLThxDVpA1OIAVK/buueRJExIWr6y4s6NtpuR8UQEcfq5hfoc4zMFGHR5+u1WFIb5siK25xh/OnS7bLdLic6AkjZSrx91+0v2Jn9gfUqbs5AJ040XzAAdx/Mb4s0+537yhB+/JXPylR1QxhGbO7koXQ5JDhAXWKCw2O1C+80mN8dbhQvDkEtsXrHrtXclcqf2TT89XAzc5HAC8NmP4SF+FafAREQB1KdaG4QAbc/gnjsX2YJD89SDL+3jMp6F7R1Ym+bWt5oWqx2tkm6HGXd3fbpfQlnfrRN60tMjjLmw1cDMhOhpdragY5zokniEUL2pKVtrxFp7V1ZpoMI0Kt5MKkOXrezi542NWSgkGehlsDLD9wtuCNem2arR0mNnMLdYkMG7G0dpAq3Tl32dgfMfyKnNyE2O/6/EeEuzUH2NfTU1p7AUQfLrf4rtNcJEs9OAPuC9vy7w9YEpF997T+FhR2Ub1C423NQj4bwlS/9f7MIBkSi1EgnQuiSGB5epxAKI3oOVrmzOpTuvr6wZXV9pM3zdfbcoGuFWP6Ix7W8G5vg+0WvoSjc2fwGXYlidEK3xlQSMAaQ4CMClpPsKLScRq1nrQGzPYoiL1DYubsOWx9ohll6+jNjKI6f79WwbHYrW4EeRIOz38+m46EDjAWZBMgrE7J/3DhgeLEVJYBA5K0=</ds:X509Certificate></ds:X509Data></ds:KeyInfo></ds:Signature>
<saml:Subject>
<saml:NameID SPNameQualifier="http://sp.example.com/demo1/metadata.php" Format="urn:oasis:names:tc:SAML:2.0:nameid-format:transient">_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7</saml:NameID>
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml:SubjectConfirmationData NotOnOrAfter="2024-01-18T06:21:48Z" Recipient="http://sp.example.com/demo1/index.php?acs" InResponseTo="ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"/>
</saml:SubjectConfirmation>
</saml:Subject>
<saml:Conditions NotBefore="2014-07-17T01:01:18Z" NotOnOrAfter="2024-01-18T06:21:48Z">
<saml:AudienceRestriction>
<saml:Audience>http://sp.example.com/demo1/metadata.php</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AuthnStatement AuthnInstant="2014-07-17T01:01:48Z" SessionNotOnOrAfter="2024-07-17T09:01:48Z" SessionIndex="_be9967abd904ddcae3c0eb4189adbe3f71e327cf93">
<saml:AuthnContext>
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:Password</saml:AuthnContextClassRef>
</saml:AuthnContext>
</saml:AuthnStatement>
<saml:AttributeStatement>
<saml:Attribute Name="uid" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic">
<saml:AttributeValue xsi:type="xs:string">test</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="mail" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic">
<saml:AttributeValue xsi:type="xs:string">test@example.com</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="eduPersonAffiliation" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic">
<saml:AttributeValue xsi:type="xs:string">users</saml:AttributeValue>
<saml:AttributeValue xsi:type="xs:string">examplerole1</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
</saml:Assertion>
</samlp:Response>
@@ -36,7 +36,9 @@ class TestPropertyMappings(TestCase):

    def test_user_base_properties(self):
        """Test user base properties"""
        properties = self.source.get_base_user_properties(root=ROOT, name_id=NAME_ID)
        properties = self.source.get_base_user_properties(
            root=ROOT, assertion=ROOT.find(f"{{{NS_SAML_ASSERTION}}}Assertion"), name_id=NAME_ID
        )
        self.assertEqual(
            properties,
            {
@@ -49,7 +51,11 @@ class TestPropertyMappings(TestCase):

    def test_group_base_properties(self):
        """Test group base properties"""
        properties = self.source.get_base_user_properties(root=ROOT_GROUPS, name_id=NAME_ID)
        properties = self.source.get_base_user_properties(
            root=ROOT_GROUPS,
            assertion=ROOT_GROUPS.find(f"{{{NS_SAML_ASSERTION}}}Assertion"),
            name_id=NAME_ID,
        )
        self.assertEqual(properties["groups"], ["group 1", "group 2"])
        for group_id in ["group 1", "group 2"]:
            properties = self.source.get_base_group_properties(root=ROOT, group_id=group_id)
@@ -164,6 +164,31 @@ class TestResponseProcessor(TestCase):
        parser = ResponseProcessor(self.source, request)
        parser.parse()

    def test_verification_assertion_duplicate(self):
        """Test verifying signature inside assertion, where the response has another assertion
        before our signed assertion"""
        key = load_fixture("fixtures/signature_cert.pem")
        kp = CertificateKeyPair.objects.create(
            name=generate_id(),
            certificate_data=key,
        )
        self.source.verification_kp = kp
        self.source.signed_assertion = True
        self.source.signed_response = False
        request = self.factory.post(
            "/",
            data={
                "SAMLResponse": b64encode(
                    load_fixture("fixtures/response_signed_assertion_dup.xml").encode()
                ).decode()
            },
        )

        parser = ResponseProcessor(self.source, request)
        parser.parse()
        self.assertNotEqual(parser._get_name_id().text, "bad")
        self.assertEqual(parser._get_name_id().text, "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7")

    def test_verification_response(self):
        """Test verifying signature inside response"""
        key = load_fixture("fixtures/signature_cert.pem")
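The duplicated-assertion fixture and the test above exercise a signature-wrapping defence: the processor has to bind itself to the assertion that actually encloses the verified Signature node, instead of taking the first Assertion in the document (which an attacker can prepend unsigned). A rough standalone sketch of that parent check with lxml, using the same namespaces as the fixture (the cryptographic verification step itself is elided and the function name is illustrative, not authentik's API):

from lxml import etree

NS_SAML_ASSERTION = "urn:oasis:names:tc:SAML:2.0:assertion"
NS_SIGNATURE = "http://www.w3.org/2000/09/xmldsig#"


def signed_assertion(root: etree._Element) -> etree._Element:
    """Return the Assertion that directly contains the (already verified) Signature."""
    signature = root.find(f".//{{{NS_SIGNATURE}}}Signature")
    if signature is None:
        raise ValueError("no Signature element found")
    # ...cryptographic verification of `signature` happens before trusting it...
    parent = signature.getparent()
    if parent is None or parent.tag != f"{{{NS_SAML_ASSERTION}}}Assertion":
        raise ValueError("Signature is not enveloped directly in an Assertion")
    # Any unsigned Assertion injected before the signed one is ignored
    return parent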
@@ -38,6 +38,7 @@ from authentik.stages.authenticator_validate.models import AuthenticatorValidate
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD_ARGS

LOGGER = get_logger()
if TYPE_CHECKING:
@@ -143,7 +144,11 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
        credentials={"username": user.username},
        request=stage_view.request,
        stage=stage_view.executor.current_stage,
        device_class=DeviceClasses.TOTP.value,
        context={
            PLAN_CONTEXT_METHOD_ARGS: {
                "device_class": DeviceClasses.TOTP.value,
            }
        },
    )
    raise ValidationError(
        _("Invalid Token. Please ensure the time on your device is accurate and try again.")
@@ -215,9 +220,13 @@ def validate_challenge_webauthn(
            credentials={"username": user.username},
            request=stage_view.request,
            stage=stage_view.executor.current_stage,
            device=device,
            device_class=DeviceClasses.WEBAUTHN.value,
            device_type=device.device_type,
            context={
                PLAN_CONTEXT_METHOD_ARGS: {
                    "device": device,
                    "device_class": DeviceClasses.WEBAUTHN.value,
                    "device_type": device.device_type,
                },
            },
        )
        raise ValidationError("Assertion failed") from exc

@@ -267,8 +276,12 @@ def validate_challenge_duo(device_pk: int, stage_view: StageView, user: User) ->
            credentials={"username": user.username},
            request=stage_view.request,
            stage=stage_view.executor.current_stage,
            device_class=DeviceClasses.DUO.value,
            duo_response=response,
            context={
                PLAN_CONTEXT_METHOD_ARGS: {
                    "device_class": DeviceClasses.DUO.value,
                    "duo_response": response,
                }
            },
        )
        raise ValidationError("Duo denied access", code="denied")
    return device
@@ -99,6 +99,7 @@ class IdentificationChallenge(Challenge):
    password_fields = BooleanField()
    allow_show_password = BooleanField(default=False)
    application_pre = CharField(required=False)
    application_pre_launch = CharField(required=False)
    flow_designation = ChoiceField(FlowDesignation.choices)
    captcha_stage = CaptchaChallenge(required=False, allow_null=True)

@@ -348,9 +349,12 @@ class IdentificationStageView(ChallengeStageView):
        # If the user has been redirected to us whilst trying to access an
        # application, PLAN_CONTEXT_APPLICATION is set in the flow plan
        if PLAN_CONTEXT_APPLICATION in self.executor.plan.context:
            challenge.initial_data["application_pre"] = self.executor.plan.context.get(
            app: Application = self.executor.plan.context.get(
                PLAN_CONTEXT_APPLICATION, Application()
            ).name
            )
            challenge.initial_data["application_pre"] = app.name
            if launch_url := app.get_launch_url():
                challenge.initial_data["application_pre_launch"] = launch_url
        if (
            PLAN_CONTEXT_DEVICE in self.executor.plan.context
            and PLAN_CONTEXT_DEVICE_AUTH_TOKEN in self.executor.plan.context
@@ -6,6 +6,7 @@ from django.contrib.auth.views import redirect_to_login
from django.http.request import HttpRequest
from structlog.stdlib import get_logger

from authentik.core.middleware import get_user
from authentik.core.models import Session
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
@@ -54,11 +55,13 @@ class SessionBindingBroken(SentryIgnoredException):

def logout_extra(request: HttpRequest, exc: SessionBindingBroken):
    """Similar to django's logout method, but able to carry more info to the signal"""
    # Dispatch the signal before the user is logged out so the receivers have a
    # chance to find out *who* logged out.
    user = getattr(request, "user", None)
    # Since this middleware runs before the AuthenticationMiddleware, we can't use `request.user`
    # as it hasn't been populated yet.
    user = get_user(request)
    if not getattr(user, "is_authenticated", True):
        user = None
    # Dispatch the signal before the user is logged out so the receivers have a
    # chance to find out *who* logged out.
    user_logged_out.send(
        sender=user.__class__, request=request, user=user, event_extra=exc.to_event()
    )
@@ -10,6 +10,8 @@ from django.utils.timezone import now
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import AuthenticatedSession, Session
from authentik.core.tests.utils import create_test_flow, create_test_user
from authentik.events.models import Event, EventAction
from authentik.events.utils import get_user
from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
@@ -270,6 +272,7 @@ class TestUserLoginStage(FlowTestCase):

    def test_session_binding_broken(self):
        """Test session binding"""
        Event.objects.all().delete()
        self.client.force_login(self.user)
        session = self.client.session
        session[Session.Keys.LAST_IP] = "192.0.2.1"
@@ -285,3 +288,5 @@ class TestUserLoginStage(FlowTestCase):
            )
            + f"?{NEXT_ARG_NAME}={reverse("authentik_api:user-me")}",
        )
        event = Event.objects.filter(action=EventAction.LOGOUT).first()
        self.assertEqual(event.user, get_user(self.user))
@@ -2,7 +2,7 @@
    "$schema": "http://json-schema.org/draft-07/schema",
    "$id": "https://goauthentik.io/blueprints/schema.json",
    "type": "object",
    "title": "authentik 2026.2.0-rc1 Blueprint schema",
    "title": "authentik 2026.2.2-rc1 Blueprint schema",
    "required": [
        "version",
        "entries"
@@ -29,7 +29,7 @@ entries:
                  password=request.user.password
              )
          # ...otherwise we set an immutable ID based on the user's UID
          user["on_premises_immutable_id"] = request.user.uid,
          user["on_premises_immutable_id"] = request.user.uid
          return user
  - identifiers:
      managed: goauthentik.io/providers/microsoft_entra/group
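The one-character change above matters because of Python's tuple syntax: the trailing comma on the removed line makes the right-hand side a one-element tuple, so the property mapping sent a tuple instead of the plain UID string. A quick demonstration (illustrative value only):

uid = "9d2c3a1e"
with_comma = uid,      # ('9d2c3a1e',) - a tuple, which is what the old mapping produced
without_comma = uid    # '9d2c3a1e'
assert with_comma == (uid,)
assert without_comma == uid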
@@ -104,7 +104,11 @@ type OutpostConfig struct {
}

type WebConfig struct {
    Path string `yaml:"path" env:"PATH, overwrite"`
    Path                  string `yaml:"path" env:"PATH, overwrite"`
    TimeoutHttpReadHeader string `yaml:"timeout_http_read_header" env:"TIMEOUT_HTTP_READ_HEADER, overwrite"`
    TimeoutHttpRead       string `yaml:"timeout_http_read" env:"TIMEOUT_HTTP_READ, overwrite"`
    TimeoutHttpWrite      string `yaml:"timeout_http_write" env:"TIMEOUT_HTTP_WRITE, overwrite"`
    TimeoutHttpIdle       string `yaml:"timeout_http_idle" env:"TIMEOUT_HTTP_IDLE, overwrite"`
}

type LogConfig struct {
@@ -1 +1 @@
2026.2.0-rc1
2026.2.2-rc1
@@ -83,10 +83,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,

    entries := make([]*ldap.Entry, 0)

    // Create a custom client to set additional headers
    c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
    c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)

    scope := req.Scope
    needUsers, needGroups := ds.si.GetNeededObjects(scope, req.BaseDN, req.FilterObjectClass)

@@ -113,7 +109,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
        errs.Go(func() error {
            if flags.CanSearch {
                uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
                searchReq, skip := utils.ParseFilterForUser(c.CoreAPI.CoreUsersList(uapisp.Context()).IncludeGroups(true), parsedFilter, false)
                searchReq, skip := utils.ParseFilterForUser(ds.si.GetAPIClient().CoreAPI.CoreUsersList(uapisp.Context()).IncludeGroups(true), parsedFilter, false)

                if skip {
                    req.Log().Trace("Skip backend request")
@@ -132,7 +128,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
            } else {
                if flags.UserInfo == nil {
                    uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
                    u, _, err := c.CoreAPI.CoreUsersRetrieve(uapisp.Context(), flags.UserPk).Execute()
                    u, _, err := ds.si.GetAPIClient().CoreAPI.CoreUsersRetrieve(uapisp.Context(), flags.UserPk).Execute()
                    uapisp.Finish()

                    if err != nil {
@@ -155,7 +151,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
    if needGroups {
        errs.Go(func() error {
            gapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_group")
            searchReq, skip := utils.ParseFilterForGroup(c.CoreAPI.CoreGroupsList(gapisp.Context()).IncludeUsers(true).IncludeChildren(true).IncludeParents(true), parsedFilter, false)
            searchReq, skip := utils.ParseFilterForGroup(ds.si.GetAPIClient().CoreAPI.CoreGroupsList(gapisp.Context()).IncludeUsers(true).IncludeChildren(true).IncludeParents(true), parsedFilter, false)
            if skip {
                req.Log().Trace("Skip backend request")
                return nil
@@ -26,7 +26,6 @@ import (
    "goauthentik.io/api/v3"
    "goauthentik.io/internal/config"
    "goauthentik.io/internal/outpost/ak"
    "goauthentik.io/internal/outpost/proxyv2/constants"
    "goauthentik.io/internal/outpost/proxyv2/hs256"
    "goauthentik.io/internal/outpost/proxyv2/metrics"
    "goauthentik.io/internal/outpost/proxyv2/templates"
@@ -294,22 +293,16 @@ func (a *Application) Stop() {

func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
    redirect := a.endpoint.EndSessionEndpoint
    s, err := a.sessions.Get(r, a.SessionName())
    if err != nil {
    cc := a.getClaimsFromSession(rw, r)
    if cc == nil {
        a.redirectToStart(rw, r)
        return
    }
    c, exists := s.Values[constants.SessionClaims]
    if c == nil && !exists {
        a.redirectToStart(rw, r)
        return
    }
    cc := c.(types.Claims)
    uv := url.Values{
        "id_token_hint": []string{cc.RawToken},
    }
    redirect += "?" + uv.Encode()
    err = a.Logout(r.Context(), func(c types.Claims) bool {
    err := a.Logout(r.Context(), func(c types.Claims) bool {
        return c.Sub == cc.Sub
    })
    if err != nil {
@@ -76,7 +76,7 @@ func (a *Application) redirectToStart(rw http.ResponseWriter, r *http.Request) {
        }
    }

    redirectUrl := urlJoin(a.proxyConfig.ExternalHost, r.URL.Path)
    redirectUrl := urlJoin(a.proxyConfig.ExternalHost, r.URL.EscapedPath())

    if a.Mode() == api.PROXYMODE_FORWARD_DOMAIN {
        dom := strings.TrimPrefix(*a.proxyConfig.CookieDomain, ".")
@@ -27,6 +27,24 @@ func TestRedirectToStart_Proxy(t *testing.T) {
    assert.Equal(t, "https://test.goauthentik.io/foo/bar/baz", s.Values[constants.SessionRedirect])
}

func TestRedirectToStart_Proxy_EncodedSlash(t *testing.T) {
    a := newTestApplication()
    a.proxyConfig.Mode = api.PROXYMODE_PROXY.Ptr()
    a.proxyConfig.ExternalHost = "https://test.goauthentik.io"
    // %2F is a URL-encoded forward slash, used by apps like RabbitMQ in queue paths
    req, _ := http.NewRequest("GET", "/api/queues/%2F/MYChannelCreated", nil)

    rr := httptest.NewRecorder()
    a.redirectToStart(rr, req)

    assert.Equal(t, http.StatusFound, rr.Code)
    loc, _ := rr.Result().Location()
    assert.Contains(t, loc.String(), "%252F", "encoded slash %2F must be preserved in redirect URL")

    s, _ := a.sessions.Get(req, a.SessionName())
    assert.Contains(t, s.Values[constants.SessionRedirect].(string), "%2F", "encoded slash %2F must be preserved in session redirect")
}

func TestRedirectToStart_Forward(t *testing.T) {
    a := newTestApplication()
    a.proxyConfig.Mode = api.PROXYMODE_FORWARD_SINGLE.Ptr()
Some files were not shown because too many files have changed in this diff.