mirror of
https://github.com/goauthentik/authentik
synced 2026-05-15 03:16:22 +02:00
Compare commits
116 Commits
sdko/remov
...
version/20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d617e4ad1 | ||
|
|
4adc0eaf8e | ||
|
|
7de405db6d | ||
|
|
50b291d6c4 | ||
|
|
14005fe781 | ||
|
|
591153b6cd | ||
|
|
864856733e | ||
|
|
1b66803a31 | ||
|
|
d8579b02ed | ||
|
|
f98d464323 | ||
|
|
7828facc41 | ||
|
|
ffe2bde51f | ||
|
|
f6dcdd059c | ||
|
|
2629759293 | ||
|
|
1b9bd8d4af | ||
|
|
c0e5ac3127 | ||
|
|
53f4bd613f | ||
|
|
83e41efe07 | ||
|
|
ad569be1d5 | ||
|
|
064866ccc7 | ||
|
|
36593d4700 | ||
|
|
2857e4df95 | ||
|
|
28b4a927ef | ||
|
|
7a20845a03 | ||
|
|
76ca2fbf77 | ||
|
|
e4e8bc57f1 | ||
|
|
15380dee37 | ||
|
|
b4844f8800 | ||
|
|
e9ff4f79ca | ||
|
|
92fb2f0f2b | ||
|
|
f80ce9dd6c | ||
|
|
a233feec29 | ||
|
|
bc9215a2ff | ||
|
|
263a2bca6d | ||
|
|
4cc71ef161 | ||
|
|
f66c535ae0 | ||
|
|
893325a7b7 | ||
|
|
a62c73d6f1 | ||
|
|
483710a59c | ||
|
|
b8b7584e8e | ||
|
|
2fedc3d0a0 | ||
|
|
7f0b45f921 | ||
|
|
3905c281ad | ||
|
|
e6099d43f5 | ||
|
|
a91145bc7b | ||
|
|
3f38d5c7d9 | ||
|
|
c00df0573c | ||
|
|
c3a0edee00 | ||
|
|
8b81ca36ea | ||
|
|
698de68a36 | ||
|
|
db35593b24 | ||
|
|
445fa31b57 | ||
|
|
a9aa1bf2c2 | ||
|
|
d018f0381c | ||
|
|
7dd1cd5c59 | ||
|
|
c219a6804a | ||
|
|
d9310d04b0 | ||
|
|
f471ef0e2e | ||
|
|
31a010c108 | ||
|
|
96e6ab291e | ||
|
|
ebf68311c2 | ||
|
|
fd365b2a09 | ||
|
|
41104da41f | ||
|
|
7edebdec03 | ||
|
|
fb56a54eb1 | ||
|
|
31cd6eb8ce | ||
|
|
092c5eb33c | ||
|
|
3e41bba54d | ||
|
|
9f8fd6eabe | ||
|
|
35fb55da15 | ||
|
|
b1d571a5af | ||
|
|
fb589592b5 | ||
|
|
6468bb5707 | ||
|
|
70406664dc | ||
|
|
c58c194180 | ||
|
|
fad87741e7 | ||
|
|
f6679895e5 | ||
|
|
a573a72ecb | ||
|
|
b72709ebbc | ||
|
|
449742fbc0 | ||
|
|
1b02cc0dae | ||
|
|
b0945ee7e9 | ||
|
|
6682136af1 | ||
|
|
24cb5ae4c1 | ||
|
|
9e272c7121 | ||
|
|
5dc7b7cdae | ||
|
|
2e2c52e49c | ||
|
|
38f1ef0506 | ||
|
|
3517562549 | ||
|
|
cdbe40143d | ||
|
|
5816f0d17c | ||
|
|
907ea8b2e9 | ||
|
|
b38af89960 | ||
|
|
d52db187bf | ||
|
|
2093e0e63f | ||
|
|
2791d87ceb | ||
|
|
fdc3d95b59 | ||
|
|
de7a61cee0 | ||
|
|
f2805b9b8a | ||
|
|
f48a91fbf4 | ||
|
|
f056c0808d | ||
|
|
06a6d45139 | ||
|
|
0e12642f12 | ||
|
|
01406d364e | ||
|
|
b9b16dba59 | ||
|
|
1ef83f3295 | ||
|
|
343506d104 | ||
|
|
aeb4e1057e | ||
|
|
0bcd1c268c | ||
|
|
ecba1ffe94 | ||
|
|
b7d303936c | ||
|
|
c1bc2a4565 | ||
|
|
1422c3aff3 | ||
|
|
d4a77583ea | ||
|
|
78d270bf25 | ||
|
|
6d1c7f90e2 |
4
.github/dependabot.yml
vendored
4
.github/dependabot.yml
vendored
@@ -142,7 +142,9 @@ updates:
|
||||
labels:
|
||||
- dependencies
|
||||
- package-ecosystem: docker
|
||||
directory: "/"
|
||||
directories:
|
||||
- /
|
||||
- /website
|
||||
schedule:
|
||||
interval: daily
|
||||
time: "04:00"
|
||||
|
||||
1
.github/workflows/api-ts-publish.yml
vendored
1
.github/workflows/api-ts-publish.yml
vendored
@@ -15,7 +15,6 @@ permissions:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- id: generate_token
|
||||
|
||||
1
.github/workflows/ci-docs-source.yml
vendored
1
.github/workflows/ci-docs-source.yml
vendored
@@ -13,7 +13,6 @@ env:
|
||||
|
||||
jobs:
|
||||
publish-source-docs:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
|
||||
2
.github/workflows/ci-docs.yml
vendored
2
.github/workflows/ci-docs.yml
vendored
@@ -61,7 +61,6 @@ jobs:
|
||||
working-directory: website/
|
||||
run: npm run build -w integrations
|
||||
build-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
# Needed to upload container images to ghcr.io
|
||||
@@ -121,4 +120,3 @@ jobs:
|
||||
- uses: re-actors/alls-green@release/v1
|
||||
with:
|
||||
jobs: ${{ toJSON(needs) }}
|
||||
allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }}
|
||||
|
||||
1
.github/workflows/ci-main-daily.yml
vendored
1
.github/workflows/ci-main-daily.yml
vendored
@@ -9,7 +9,6 @@ on:
|
||||
|
||||
jobs:
|
||||
test-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
10
.github/workflows/ci-main.yml
vendored
10
.github/workflows/ci-main.yml
vendored
@@ -80,7 +80,15 @@ jobs:
|
||||
cp authentik/lib/default.yml local.env.yml
|
||||
cp -R .github ..
|
||||
cp -R scripts ..
|
||||
git checkout $(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1)
|
||||
# Previous stable tag
|
||||
prev_stable=$(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1)
|
||||
# Current version family based on
|
||||
current_version_family=$(python -c "from authentik import VERSION; print(VERSION)" | grep -vE -- 'rc[0-9]+$')
|
||||
if [[ -n $current_version_family ]]; then
|
||||
prev_stable=$current_version_family
|
||||
fi
|
||||
echo "::notice::Checking out ${prev_stable} as stable version..."
|
||||
git checkout $(prev_stable)
|
||||
rm -rf .github/ scripts/
|
||||
mv ../.github ../scripts .
|
||||
- name: Setup authentik env (stable)
|
||||
|
||||
1
.github/workflows/ci-outpost.yml
vendored
1
.github/workflows/ci-outpost.yml
vendored
@@ -67,7 +67,6 @@ jobs:
|
||||
with:
|
||||
jobs: ${{ toJSON(needs) }}
|
||||
build-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
timeout-minutes: 120
|
||||
needs:
|
||||
- ci-outpost-mark
|
||||
|
||||
@@ -13,7 +13,6 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- id: generate_token
|
||||
|
||||
15
.github/workflows/gh-ghcr-retention.yml
vendored
15
.github/workflows/gh-ghcr-retention.yml
vendored
@@ -5,10 +5,13 @@ on:
|
||||
# schedule:
|
||||
# - cron: "0 0 * * *" # every day at midnight
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry-run:
|
||||
type: boolean
|
||||
description: Enable dry-run mode
|
||||
|
||||
jobs:
|
||||
clean-ghcr:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
name: Delete old unused container images
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -18,12 +21,12 @@ jobs:
|
||||
app-id: ${{ secrets.GH_APP_ID }}
|
||||
private-key: ${{ secrets.GH_APP_PRIVATE_KEY }}
|
||||
- name: Delete 'dev' containers older than a week
|
||||
uses: snok/container-retention-policy@3b0972b2276b171b212f8c4efbca59ebba26eceb # v2
|
||||
uses: snok/container-retention-policy@3b0972b2276b171b212f8c4efbca59ebba26eceb # v3.0.1
|
||||
with:
|
||||
image-names: dev-server,dev-ldap,dev-proxy
|
||||
image-tags: "!gh-next,!gh-main"
|
||||
cut-off: One week ago UTC
|
||||
account-type: org
|
||||
org-name: goauthentik
|
||||
untagged-only: false
|
||||
account: goauthentik
|
||||
tag-selection: untagged
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
skip-tags: gh-next,gh-main
|
||||
dry-run: ${{ inputs.dry-run }}
|
||||
|
||||
1
.github/workflows/packages-npm-publish.yml
vendored
1
.github/workflows/packages-npm-publish.yml
vendored
@@ -19,7 +19,6 @@ permissions:
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
1
.github/workflows/release-next-branch.yml
vendored
1
.github/workflows/release-next-branch.yml
vendored
@@ -12,7 +12,6 @@ permissions:
|
||||
|
||||
jobs:
|
||||
update-next:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
environment: internal-production
|
||||
steps:
|
||||
|
||||
2
.github/workflows/release-tag.yml
vendored
2
.github/workflows/release-tag.yml
vendored
@@ -87,7 +87,7 @@ jobs:
|
||||
git tag "version/${{ inputs.version }}" HEAD -m "version/${{ inputs.version }}"
|
||||
git push --follow-tags
|
||||
- name: Create Release
|
||||
uses: softprops/action-gh-release@6da8fa9354ddfdc4aeace5fc48d7f679b5214090 # v2
|
||||
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2
|
||||
with:
|
||||
token: "${{ steps.app-token.outputs.token }}"
|
||||
tag_name: "version/${{ inputs.version }}"
|
||||
|
||||
22
.github/workflows/repo-mirror-cleanup.yml
vendored
22
.github/workflows/repo-mirror-cleanup.yml
vendored
@@ -1,22 +0,0 @@
|
||||
---
|
||||
name: Repo - Cleanup internal mirror
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
to_internal:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- if: ${{ env.MIRROR_KEY != '' }}
|
||||
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb # 5cf300935bc2e068f73ea69bcc411a8a997208eb
|
||||
with:
|
||||
target_repo_url: git@github.com:goauthentik/authentik-internal.git
|
||||
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
|
||||
args: --tags --force --prune
|
||||
env:
|
||||
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
|
||||
21
.github/workflows/repo-mirror.yml
vendored
21
.github/workflows/repo-mirror.yml
vendored
@@ -1,21 +0,0 @@
|
||||
---
|
||||
name: Repo - Mirror to internal
|
||||
|
||||
on: [push, delete]
|
||||
|
||||
jobs:
|
||||
to_internal:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- if: ${{ env.MIRROR_KEY != '' }}
|
||||
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb # 5cf300935bc2e068f73ea69bcc411a8a997208eb
|
||||
with:
|
||||
target_repo_url: git@github.com:goauthentik/authentik-internal.git
|
||||
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
|
||||
args: --tags --force
|
||||
env:
|
||||
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
|
||||
1
.github/workflows/repo-stale.yml
vendored
1
.github/workflows/repo-stale.yml
vendored
@@ -12,7 +12,6 @@ permissions:
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- id: generate_token
|
||||
|
||||
@@ -17,7 +17,6 @@ env:
|
||||
|
||||
jobs:
|
||||
compile:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- id: generate_token
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build webui
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:24-slim AS node-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:24-trixie-slim@sha256:45babd1b4ce0349fb12c4e24bf017b90b96d52806db32e001e3013f341bef0fe AS node-builder
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
@@ -26,7 +26,7 @@ RUN npm run build && \
|
||||
npm run build:sfe
|
||||
|
||||
# Stage 2: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.3-bookworm AS go-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.3-trixie@sha256:7534a6264850325fcce93e47b87a0e3fddd96b308440245e6ab1325fa8a44c91 AS go-builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@@ -63,7 +63,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 3: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1 AS geoip
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1@sha256:faecdca22579730ab0b7dea5aa9af350bb3c93cb9d39845c173639ead30346d2 AS geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="1"
|
||||
@@ -76,9 +76,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 4: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.9.4 AS uv
|
||||
FROM ghcr.io/astral-sh/uv:0.9.6@sha256:4b96ee9429583983fd172c33a02ecac5242d63fb46bc27804748e38c1cc9ad0d AS uv
|
||||
# Stage 5: Base python image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.13.9-slim-trixie-fips AS python-base
|
||||
FROM ghcr.io/goauthentik/fips-python:3.13.9-slim-trixie-fips@sha256:700fc8c1e290bd14e5eaca50b1d8e8c748c820010559cbfb4c4f8dfbe2c4c9ff AS python-base
|
||||
|
||||
ENV VENV_PATH="/ak-root/.venv" \
|
||||
PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \
|
||||
@@ -139,6 +139,7 @@ ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
||||
LABEL org.opencontainers.image.authors="Authentik Security Inc." \
|
||||
org.opencontainers.image.source="https://github.com/goauthentik/authentik" \
|
||||
org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info." \
|
||||
org.opencontainers.image.documentation="https://docs.goauthentik.io" \
|
||||
org.opencontainers.image.licenses="https://github.com/goauthentik/authentik/blob/main/LICENSE" \
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
|
||||
VERSION = "2025.10.0-rc1"
|
||||
VERSION = "2025.10.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Test brands"""
|
||||
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.brands.api import Themes
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import Application
|
||||
@@ -23,6 +26,7 @@ class TestBrands(APITestCase):
|
||||
_flag = flag()
|
||||
if _flag.visibility == "public":
|
||||
self.default_flags[_flag.key] = _flag.get()
|
||||
Brand.objects.all().delete()
|
||||
|
||||
def test_current_brand(self):
|
||||
"""Test Current brand API"""
|
||||
@@ -44,7 +48,6 @@ class TestBrands(APITestCase):
|
||||
|
||||
def test_brand_subdomain(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.all().delete()
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom")
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
@@ -65,7 +68,6 @@ class TestBrands(APITestCase):
|
||||
|
||||
def test_fallback(self):
|
||||
"""Test fallback brand"""
|
||||
Brand.objects.all().delete()
|
||||
self.assertJSONEqual(
|
||||
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
|
||||
{
|
||||
@@ -81,6 +83,109 @@ class TestBrands(APITestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@apply_blueprint("default/default-brand.yaml")
|
||||
def test_blueprint(self):
|
||||
"""Test Current brand API"""
|
||||
response = loads(self.client.get(reverse("authentik_api:brand-current")).content.decode())
|
||||
response.pop("flow_authentication", None)
|
||||
response.pop("flow_invalidation", None)
|
||||
response.pop("flow_user_settings", None)
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "authentik-default",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
)
|
||||
|
||||
@apply_blueprint("default/default-brand.yaml")
|
||||
def test_blueprint_with_other_brand(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom")
|
||||
response = loads(self.client.get(reverse("authentik_api:brand-current")).content.decode())
|
||||
response.pop("flow_authentication", None)
|
||||
response.pop("flow_invalidation", None)
|
||||
response.pop("flow_user_settings", None)
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "authentik-default",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
)
|
||||
|
||||
def test_brand_subdomain_same_suffix(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom-weak")
|
||||
Brand.objects.create(domain="foo.bar.baz", branding_title="custom-strong")
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom-strong",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "foo.bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
)
|
||||
|
||||
def test_brand_subdomain_other_suffix(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom-weak")
|
||||
Brand.objects.create(domain="foo.bar.baz", branding_title="custom-strong")
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
reverse("authentik_api:brand-current"), HTTP_HOST="other.bar.baz"
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom-weak",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
)
|
||||
|
||||
def test_create_default_multiple(self):
|
||||
"""Test attempted creation of multiple default brands"""
|
||||
Brand.objects.create(
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import F, Q
|
||||
from django.db.models import Value as V
|
||||
from django.db.models import Case, F, IntegerField, Q, Value, When
|
||||
from django.db.models.functions import Length
|
||||
from django.http.request import HttpRequest
|
||||
from django.utils.html import _json_script_escapes
|
||||
from django.utils.safestring import mark_safe
|
||||
@@ -19,15 +19,36 @@ DEFAULT_BRAND = Brand(domain="fallback")
|
||||
|
||||
def get_brand_for_request(request: HttpRequest) -> Brand:
|
||||
"""Get brand object for current request"""
|
||||
db_brands = (
|
||||
Brand.objects.annotate(host_domain=V(request.get_host()))
|
||||
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
|
||||
.order_by("default")
|
||||
|
||||
brand = (
|
||||
Brand.objects.annotate(
|
||||
host_domain=Value(request.get_host()),
|
||||
domain_length=Length("domain"),
|
||||
match_priority=Case(
|
||||
When(
|
||||
condition=Q(host_domain__iendswith=F("domain")),
|
||||
then=F("domain_length"),
|
||||
),
|
||||
default=Value(-1),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
is_default_fallback=Case(
|
||||
When(
|
||||
condition=Q(default=True),
|
||||
then=Value(0),
|
||||
),
|
||||
default=Value(-2),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
)
|
||||
.filter(Q(match_priority__gt=-1) | Q(default=True))
|
||||
.order_by("-match_priority", "-is_default_fallback")
|
||||
.first()
|
||||
)
|
||||
brands = list(db_brands.all())
|
||||
if len(brands) < 1:
|
||||
|
||||
if brand is None:
|
||||
return DEFAULT_BRAND
|
||||
return brands[0]
|
||||
return brand
|
||||
|
||||
|
||||
def context_processor(request: HttpRequest) -> dict[str, Any]:
|
||||
|
||||
@@ -4,7 +4,8 @@ from collections.abc import Iterator
|
||||
from copy import copy
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models import Case, QuerySet
|
||||
from django.db.models.expressions import When
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
@@ -23,6 +24,7 @@ from authentik.api.pagination import Pagination
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import UserSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.core.models import Application, User
|
||||
from authentik.events.logs import LogEventSerializer, capture_logs
|
||||
@@ -63,9 +65,21 @@ class ApplicationSerializer(ModelSerializer):
|
||||
def get_launch_url(self, app: Application) -> str | None:
|
||||
"""Allow formatting of launch URL"""
|
||||
user = None
|
||||
user_data = None
|
||||
|
||||
if "request" in self.context:
|
||||
user = self.context["request"].user
|
||||
return app.get_launch_url(user)
|
||||
|
||||
# Cache serialized user data to avoid N+1 when formatting launch URLs
|
||||
# for multiple applications. UserSerializer accesses user.ak_groups which
|
||||
# would otherwise trigger a query for each application.
|
||||
if user is not None:
|
||||
if "_cached_user_data" not in self.context:
|
||||
# Prefetch groups to avoid N+1
|
||||
self.context["_cached_user_data"] = UserSerializer(instance=user).data
|
||||
user_data = self.context["_cached_user_data"]
|
||||
|
||||
return app.get_launch_url(user, user_data=user_data)
|
||||
|
||||
def validate_slug(self, slug: str) -> str:
|
||||
if slug in Application.reserved_slugs:
|
||||
@@ -158,8 +172,23 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
applications.append(application)
|
||||
return applications
|
||||
|
||||
def _expand_applications(self, applications: list[Application]) -> QuerySet[Application]:
|
||||
"""
|
||||
Re-fetch with proper prefetching for serialization
|
||||
Cached applications don't have prefetched relationships, causing N+1 queries
|
||||
during serialization when get_provider() is called
|
||||
"""
|
||||
if not applications:
|
||||
return self.get_queryset().none()
|
||||
pks = [app.pk for app in applications]
|
||||
return (
|
||||
self.get_queryset()
|
||||
.filter(pk__in=pks)
|
||||
.order_by(Case(*[When(pk=pk, then=pos) for pos, pk in enumerate(pks)]))
|
||||
)
|
||||
|
||||
def _filter_applications_with_launch_url(
|
||||
self, paginated_apps: Iterator[Application]
|
||||
self, paginated_apps: QuerySet[Application]
|
||||
) -> list[Application]:
|
||||
applications = []
|
||||
for app in paginated_apps:
|
||||
@@ -262,6 +291,8 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
except ValueError as exc:
|
||||
raise ValidationError from exc
|
||||
allowed_applications = self._get_allowed_applications(paginated_apps, user=for_user)
|
||||
allowed_applications = self._expand_applications(allowed_applications)
|
||||
|
||||
serializer = self.get_serializer(allowed_applications, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
@@ -280,6 +311,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
allowed_applications,
|
||||
timeout=86400,
|
||||
)
|
||||
allowed_applications = self._expand_applications(allowed_applications)
|
||||
|
||||
if only_with_launch_url == "true":
|
||||
allowed_applications = self._filter_applications_with_launch_url(allowed_applications)
|
||||
|
||||
@@ -15,7 +15,7 @@ from django.db import models
|
||||
from django.db.models import Q, QuerySet, options
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.http import HttpRequest
|
||||
from django.utils.functional import SimpleLazyObject, cached_property
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_cte import CTE, with_cte
|
||||
@@ -524,6 +524,10 @@ class ApplicationQuerySet(QuerySet):
|
||||
qs = self.select_related("provider")
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
qs = qs.select_related(f"provider__{subclass}")
|
||||
# Also prefetch/select through each subclass path to ensure casted instances have access
|
||||
qs = qs.prefetch_related(f"provider__{subclass}__property_mappings")
|
||||
qs = qs.select_related(f"provider__{subclass}__application")
|
||||
qs = qs.select_related(f"provider__{subclass}__backchannel_application")
|
||||
return qs
|
||||
|
||||
|
||||
@@ -583,20 +587,28 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
return CONFIG.get("web.path", "/")[:-1] + self.meta_icon.name
|
||||
return self.meta_icon.url
|
||||
|
||||
def get_launch_url(self, user: Optional["User"] = None) -> str | None:
|
||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider."""
|
||||
def get_launch_url(
|
||||
self, user: Optional["User"] = None, user_data: dict | None = None
|
||||
) -> str | None:
|
||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider.
|
||||
|
||||
Args:
|
||||
user: User instance for formatting the URL
|
||||
user_data: Pre-serialized user data to avoid re-serialization (performance optimization)
|
||||
"""
|
||||
from authentik.core.api.users import UserSerializer
|
||||
|
||||
url = None
|
||||
if self.meta_launch_url:
|
||||
url = self.meta_launch_url
|
||||
elif provider := self.get_provider():
|
||||
url = provider.launch_url
|
||||
if user and url:
|
||||
if isinstance(user, SimpleLazyObject):
|
||||
user._setup()
|
||||
user = user._wrapped
|
||||
try:
|
||||
return url % user.__dict__
|
||||
|
||||
# Use pre-serialized data if available, otherwise serialize now
|
||||
if user_data is None:
|
||||
user_data = UserSerializer(instance=user).data
|
||||
return url % user_data
|
||||
except Exception as exc: # noqa
|
||||
LOGGER.warning("Failed to format launch url", exc=exc)
|
||||
return url
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""authentik core signals"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.contrib.auth.signals import user_logged_in
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Model
|
||||
@@ -17,6 +19,8 @@ from authentik.core.models import (
|
||||
User,
|
||||
default_token_duration,
|
||||
)
|
||||
from authentik.flows.apps import RefreshOtherFlowsAfterAuthentication
|
||||
from authentik.root.ws.consumer import build_device_group
|
||||
|
||||
# Arguments: user: User, password: str
|
||||
password_changed = Signal()
|
||||
@@ -47,6 +51,16 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
|
||||
if session:
|
||||
session.save()
|
||||
|
||||
if not RefreshOtherFlowsAfterAuthentication().get():
|
||||
return
|
||||
layer = get_channel_layer()
|
||||
device_cookie = request.COOKIES.get("authentik_device")
|
||||
if device_cookie:
|
||||
async_to_sync(layer.group_send)(
|
||||
build_device_group(device_cookie),
|
||||
{"type": "event.session.authenticated"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(post_delete, sender=AuthenticatedSession)
|
||||
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
|
||||
|
||||
@@ -28,8 +28,8 @@ from authentik.core.views.interface import (
|
||||
)
|
||||
from authentik.flows.views.interface import FlowInterfaceView
|
||||
from authentik.root.asgi_middleware import AuthMiddlewareStack
|
||||
from authentik.root.messages.consumer import MessageConsumer
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
from authentik.root.ws.consumer import MessageConsumer
|
||||
from authentik.tenants.channels import TenantsAwareMiddleware
|
||||
|
||||
urlpatterns = [
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from binascii import hexlify
|
||||
from hashlib import md5
|
||||
from ssl import PEM_FOOTER, PEM_HEADER
|
||||
from textwrap import wrap
|
||||
from uuid import uuid4
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@@ -20,6 +22,11 @@ from authentik.lib.models import CreatedUpdatedModel, SerializerModel
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def format_cert(raw_pam: str) -> str:
|
||||
"""Format a PEM certificate that is either missing its header/footer or is in a single line"""
|
||||
return "\n".join([PEM_HEADER, *wrap(raw_pam.replace("\n", ""), 64), PEM_FOOTER])
|
||||
|
||||
|
||||
def fingerprint_sha256(cert: Certificate) -> str:
|
||||
"""Get SHA256 Fingerprint of certificate"""
|
||||
return hexlify(cert.fingerprint(hashes.SHA256()), ":").decode("utf-8")
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
"""Enterprise app config"""
|
||||
|
||||
from django.conf import settings
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.common import ScheduleSpec
|
||||
|
||||
GAUGE_LICENSE_USAGE = Gauge(
|
||||
"authentik_enterprise_license_usage",
|
||||
"Enterprise license usage (percentage per user type).",
|
||||
["user_type"],
|
||||
)
|
||||
GAUGE_LICENSE_EXPIRY = Gauge(
|
||||
"authentik_enterprise_license_expiry_seconds", "Duration until license expires, in seconds."
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseConfig(ManagedAppConfig):
|
||||
"""Base app config for all enterprise apps"""
|
||||
|
||||
@@ -217,7 +217,7 @@ class LicenseKey:
|
||||
def summary(self) -> LicenseSummary:
|
||||
"""Summary of license status"""
|
||||
status = self.status()
|
||||
latest_valid = datetime.fromtimestamp(self.exp)
|
||||
latest_valid = datetime.fromtimestamp(self.exp).replace(tzinfo=UTC)
|
||||
return LicenseSummary(
|
||||
latest_valid=latest_valid,
|
||||
internal_users=self.internal_users,
|
||||
|
||||
@@ -42,6 +42,8 @@ def send_ssf_events(
|
||||
for stream in Stream.objects.filter(**stream_filter):
|
||||
event_data = stream.prepare_event_payload(event_type, data, **extra_data)
|
||||
events_data[stream.uuid] = event_data
|
||||
if not events_data:
|
||||
return
|
||||
ssf_events_dispatch.send(events_data)
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,41 @@
|
||||
"""Enterprise signals"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models.signals import post_delete, post_save, pre_save
|
||||
from django.dispatch import receiver
|
||||
from django.utils.timezone import get_current_timezone
|
||||
from django.utils.timezone import get_current_timezone, now
|
||||
|
||||
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.enterprise.apps import GAUGE_LICENSE_EXPIRY, GAUGE_LICENSE_USAGE
|
||||
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE, LicenseKey
|
||||
from authentik.enterprise.models import License, LicenseUsageStatus
|
||||
from authentik.enterprise.tasks import enterprise_update_usage
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
|
||||
@receiver(monitoring_set)
|
||||
def monitoring_set_enterprise(sender, **kwargs):
|
||||
"""set enterprise gauges"""
|
||||
summary = LicenseKey.cached_summary()
|
||||
if summary.status == LicenseUsageStatus.UNLICENSED:
|
||||
return
|
||||
percentage_internal = (
|
||||
0
|
||||
if summary.internal_users <= 0
|
||||
else LicenseKey.get_internal_user_count() / (summary.internal_users / 100)
|
||||
)
|
||||
percentage_external = (
|
||||
0
|
||||
if summary.external_users <= 0
|
||||
else LicenseKey.get_external_user_count() / (summary.external_users / 100)
|
||||
)
|
||||
GAUGE_LICENSE_USAGE.labels(user_type="internal").set(percentage_internal)
|
||||
GAUGE_LICENSE_USAGE.labels(user_type="external").set(percentage_external)
|
||||
GAUGE_LICENSE_EXPIRY.set((summary.latest_valid.replace(tzinfo=UTC) - now()).total_seconds())
|
||||
|
||||
|
||||
@receiver(pre_save, sender=License)
|
||||
def pre_save_license(sender: type[License], instance: License, **_):
|
||||
"""Extract data from license jwt and save it into model"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from binascii import hexlify
|
||||
from enum import IntFlag, auto
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
@@ -17,7 +18,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import User
|
||||
from authentik.crypto.models import CertificateKeyPair, fingerprint_sha256
|
||||
from authentik.crypto.models import CertificateKeyPair, fingerprint_sha256, format_cert
|
||||
from authentik.enterprise.stages.mtls.models import (
|
||||
CertAttributes,
|
||||
MutualTLSStage,
|
||||
@@ -43,14 +44,28 @@ HEADER_OUTPOST_FORWARDED = "X-Authentik-Outpost-Certificate"
|
||||
PLAN_CONTEXT_CERTIFICATE = "certificate"
|
||||
|
||||
|
||||
class ParseOptions(IntFlag):
|
||||
|
||||
# URL unquote the string
|
||||
UNQUOTE = auto()
|
||||
# Re-add PEM Header & footer, and chunk it into 64 character lines
|
||||
FORMAT = auto()
|
||||
|
||||
|
||||
class MTLSStageView(ChallengeStageView):
|
||||
|
||||
def __parse_single_cert(self, raw: str | None) -> list[Certificate]:
|
||||
def __parse_single_cert(self, raw: str | None, *options: ParseOptions) -> list[Certificate]:
|
||||
"""Helper to parse a single certificate"""
|
||||
if not raw:
|
||||
return []
|
||||
for opt in options:
|
||||
match opt:
|
||||
case ParseOptions.FORMAT:
|
||||
raw = format_cert(raw)
|
||||
case ParseOptions.UNQUOTE:
|
||||
raw = unquote_plus(raw)
|
||||
try:
|
||||
cert = load_pem_x509_certificate(unquote_plus(raw).encode())
|
||||
cert = load_pem_x509_certificate(raw.encode())
|
||||
return [cert]
|
||||
except ValueError as exc:
|
||||
self.logger.info("Failed to parse certificate", exc=exc)
|
||||
@@ -59,6 +74,7 @@ class MTLSStageView(ChallengeStageView):
|
||||
def _parse_cert_xfcc(self) -> list[Certificate]:
|
||||
"""Parse certificates in the format given to us in
|
||||
the format of the authentik router/envoy"""
|
||||
# https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-client-cert
|
||||
xfcc_raw = self.request.headers.get(HEADER_PROXY_FORWARDED)
|
||||
if not xfcc_raw:
|
||||
return []
|
||||
@@ -68,18 +84,26 @@ class MTLSStageView(ChallengeStageView):
|
||||
raw_cert = {k.split("=")[0]: k.split("=")[1] for k in el}
|
||||
if "Cert" not in raw_cert:
|
||||
continue
|
||||
certs.extend(self.__parse_single_cert(raw_cert["Cert"]))
|
||||
certs.extend(self.__parse_single_cert(raw_cert["Cert"], ParseOptions.UNQUOTE))
|
||||
return certs
|
||||
|
||||
def _parse_cert_nginx(self) -> list[Certificate]:
|
||||
"""Parse certificates in the format nginx-ingress gives to us"""
|
||||
# https://kubernetes.github.io/ingress-nginx/user-guide/nginx-configuration/annotations/#client-certificate-authentication
|
||||
# https://github.com/kubernetes/ingress-nginx/blob/78f593b24494a0674b362faf551079f06d71b5a9/rootfs/etc/nginx/template/nginx.tmpl#L1096
|
||||
sslcc_raw = self.request.headers.get(HEADER_NGINX_FORWARDED)
|
||||
return self.__parse_single_cert(sslcc_raw)
|
||||
return self.__parse_single_cert(sslcc_raw, ParseOptions.UNQUOTE)
|
||||
|
||||
def _parse_cert_traefik(self) -> list[Certificate]:
|
||||
"""Parse certificates in the format traefik gives to us"""
|
||||
# https://doc.traefik.io/traefik/reference/routing-configuration/http/middlewares/passtlsclientcert/
|
||||
ftcc_raw = self.request.headers.get(HEADER_TRAEFIK_FORWARDED)
|
||||
return self.__parse_single_cert(ftcc_raw)
|
||||
if not ftcc_raw:
|
||||
return []
|
||||
certs = []
|
||||
for cert in ftcc_raw.split(","):
|
||||
certs.extend(self.__parse_single_cert(cert, ParseOptions.UNQUOTE, ParseOptions.FORMAT))
|
||||
return certs
|
||||
|
||||
def _parse_cert_outpost(self) -> list[Certificate]:
|
||||
"""Parse certificates in the format outposts give to us. Also authenticates
|
||||
@@ -92,7 +116,7 @@ class MTLSStageView(ChallengeStageView):
|
||||
) and not user.has_perm("authentik_stages_mtls.pass_outpost_certificate"):
|
||||
return []
|
||||
outpost_raw = self.request.headers.get(HEADER_OUTPOST_FORWARDED)
|
||||
return self.__parse_single_cert(outpost_raw)
|
||||
return self.__parse_single_cert(outpost_raw, ParseOptions.UNQUOTE)
|
||||
|
||||
def get_authorities(self) -> list[CertificateKeyPair] | None:
|
||||
# We can't access `certificate_authorities` on `self.executor.current_stage`, as that would
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from ssl import PEM_FOOTER, PEM_HEADER
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
@@ -51,6 +52,10 @@ class MTLSStageTests(FlowTestCase):
|
||||
User.objects.filter(username="client").delete()
|
||||
self.cert_user = create_test_user(username="client")
|
||||
|
||||
def _format_traefik(self, cert: str | None = None):
|
||||
cert = cert if cert else self.client_cert
|
||||
return quote_plus(cert.replace(PEM_HEADER, "").replace(PEM_FOOTER, "").replace("\n", ""))
|
||||
|
||||
def test_parse_xfcc(self):
|
||||
"""Test authentik Proxy/Envoy's XFCC format"""
|
||||
with self.assertFlowFinishes() as plan:
|
||||
@@ -78,7 +83,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
@@ -138,7 +143,9 @@ class MTLSStageTests(FlowTestCase):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(cert.certificate_data)},
|
||||
headers={
|
||||
"X-Forwarded-TLS-Client-Cert": self._format_traefik(cert.certificate_data)
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageResponse(res, self.flow, component="ak-stage-access-denied")
|
||||
@@ -149,7 +156,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
User.objects.filter(username="client").delete()
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageResponse(res, self.flow, component="ak-stage-access-denied")
|
||||
@@ -163,7 +170,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
@@ -176,7 +183,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
self.stage.save()
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
@@ -187,7 +194,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
self.stage.save()
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageResponse(res, self.flow, component="ak-stage-access-denied")
|
||||
@@ -209,7 +216,7 @@ class MTLSStageTests(FlowTestCase):
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
headers={"X-Forwarded-TLS-Client-Cert": quote_plus(self.client_cert)},
|
||||
headers={"X-Forwarded-TLS-Client-Cert": self._format_traefik()},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
|
||||
49
authentik/enterprise/tests/test_metrics.py
Normal file
49
authentik/enterprise/tests/test_metrics.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Enterprise metrics tests"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import TestCase
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.enterprise.tests.test_license import expiry_valid
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
|
||||
|
||||
class TestEnterpriseMetrics(TestCase):
|
||||
"""Enterprise metrics tests"""
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_usage_empty(self):
|
||||
"""Test usage (no users)"""
|
||||
License.objects.create(key=generate_id())
|
||||
User.objects.all().delete()
|
||||
create_test_user()
|
||||
monitoring_set.send_robust(self)
|
||||
self.assertEqual(
|
||||
REGISTRY.get_sample_value(
|
||||
"authentik_enterprise_license_usage", {"user_type": "internal"}
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
self.assertEqual(
|
||||
REGISTRY.get_sample_value(
|
||||
"authentik_enterprise_license_usage", {"user_type": "external"}
|
||||
),
|
||||
0,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from django.utils.timezone import now
|
||||
@@ -28,7 +28,7 @@ class LogEvent:
|
||||
def from_event_dict(item: EventDict) -> "LogEvent":
|
||||
event = item.pop("event")
|
||||
log_level = item.pop("level").lower()
|
||||
timestamp = datetime.fromisoformat(item.pop("timestamp"))
|
||||
timestamp = datetime.fromisoformat(item.pop("timestamp")).replace(tzinfo=UTC)
|
||||
item.pop("pid", None)
|
||||
# Sometimes log entries have both `level` and `log_level` set, but `level` is always set
|
||||
item.pop("log_level", None)
|
||||
|
||||
@@ -4,6 +4,7 @@ from prometheus_client import Gauge, Histogram
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.tenants.flags import Flag
|
||||
|
||||
GAUGE_FLOWS_CACHED = Gauge(
|
||||
"authentik_flows_cached",
|
||||
@@ -22,6 +23,12 @@ HIST_FLOWS_PLAN_TIME = Histogram(
|
||||
)
|
||||
|
||||
|
||||
class RefreshOtherFlowsAfterAuthentication(Flag[bool], key="flows_refresh_others"):
|
||||
|
||||
default = False
|
||||
visibility = "public"
|
||||
|
||||
|
||||
class AuthentikFlowsConfig(ManagedAppConfig):
|
||||
"""authentik flows app config"""
|
||||
|
||||
|
||||
@@ -145,7 +145,6 @@ worker:
|
||||
consumer_listen_timeout: "seconds=30"
|
||||
task_max_retries: 5
|
||||
task_default_time_limit: "minutes=10"
|
||||
lock_purge_interval: "minutes=1"
|
||||
task_purge_interval: "days=1"
|
||||
task_expiration: "days=30"
|
||||
scheduler_interval: "seconds=60"
|
||||
|
||||
@@ -28,6 +28,8 @@ def register_signals(
|
||||
# This primarily happens during user login
|
||||
if sender == User and update_fields == {"last_login"}:
|
||||
return
|
||||
if not provider_type.objects.exists():
|
||||
return
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
@@ -39,6 +41,8 @@ def register_signals(
|
||||
|
||||
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
||||
"""Pre-delete handler"""
|
||||
if not provider_type.objects.exists():
|
||||
return
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
@@ -54,6 +58,8 @@ def register_signals(
|
||||
"""Sync group membership"""
|
||||
if action not in ["post_add", "post_remove"]:
|
||||
return
|
||||
if not provider_type.objects.exists():
|
||||
return
|
||||
task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse)
|
||||
|
||||
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
|
||||
|
||||
@@ -203,6 +203,12 @@ class DockerController(BaseController):
|
||||
"labels": self._get_labels(),
|
||||
"restart_policy": {"Name": "unless-stopped"},
|
||||
"network": self.outpost.config.docker_network,
|
||||
"healthcheck": {
|
||||
"test": ["CMD", f"/{self.outpost.type}", "healthcheck"],
|
||||
"interval": 5 * 1_000 * 1_000_000,
|
||||
"retries": 20,
|
||||
"start_period": 3 * 1_000 * 1_000_000,
|
||||
},
|
||||
}
|
||||
if self.outpost.config.docker_map_ports:
|
||||
container_args["ports"] = {
|
||||
|
||||
@@ -49,6 +49,9 @@ def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_):
|
||||
if action not in ["post_add", "post_remove", "post_clear"]:
|
||||
return
|
||||
if isinstance(instance, Outpost):
|
||||
# Rebuild permissions when providers change
|
||||
LOGGER.debug("Rebuilding outpost service account permissions", outpost=instance)
|
||||
instance.build_user_permissions(instance.user)
|
||||
outpost_controller.send_with_options(
|
||||
args=(instance.pk,),
|
||||
rel_obj=instance.service_connection,
|
||||
@@ -92,6 +95,15 @@ def outpost_post_save(sender, instance: Outpost, created: bool, **_):
|
||||
|
||||
def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
|
||||
for outpost in instance.outpost_set.all():
|
||||
# Rebuild permissions in case provider's required objects changed
|
||||
if isinstance(instance, OutpostModel):
|
||||
LOGGER.info(
|
||||
"Provider changed, rebuilding permissions and sending update",
|
||||
outpost=outpost.name,
|
||||
provider=instance.name if hasattr(instance, "name") else str(instance),
|
||||
)
|
||||
outpost.build_user_permissions(outpost.user)
|
||||
LOGGER.debug("Sending update to outpost", outpost=outpost.name, trigger="provider_change")
|
||||
outpost_send_update.send_with_options(
|
||||
args=(outpost.pk,),
|
||||
rel_obj=outpost,
|
||||
|
||||
@@ -109,7 +109,7 @@ def user_session_deleted_oauth_backchannel_logout_and_tokens_removal(
|
||||
"""Revoke tokens upon user logout"""
|
||||
LOGGER.debug("Sending back-channel logout notifications signal!", session=instance)
|
||||
|
||||
access_tokens = AccessToken.objects.filter(
|
||||
access_tokens = AccessToken.objects.select_related("provider").filter(
|
||||
user=instance.user,
|
||||
session__session__session_key=instance.session.session_key,
|
||||
)
|
||||
@@ -128,7 +128,8 @@ def user_session_deleted_oauth_backchannel_logout_and_tokens_removal(
|
||||
and token.provider.logout_method == OAuth2LogoutMethod.BACKCHANNEL
|
||||
]
|
||||
|
||||
backchannel_logout_notification_dispatch.send(revocations=backchannel_tokens)
|
||||
if backchannel_tokens:
|
||||
backchannel_logout_notification_dispatch.send(revocations=backchannel_tokens)
|
||||
|
||||
access_tokens.delete()
|
||||
|
||||
|
||||
@@ -126,6 +126,30 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_deactivate(self):
|
||||
"""test deactivated user"""
|
||||
self.user.is_active = False
|
||||
self.user.save()
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
{
|
||||
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
|
||||
"scope": SCOPE_OPENID,
|
||||
"client_id": self.provider.client_id,
|
||||
"username": "sa",
|
||||
"password": self.token.key,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
"""test permission denied"""
|
||||
group = Group.objects.create(name="foo")
|
||||
|
||||
@@ -336,7 +336,7 @@ class TokenParams:
|
||||
self, request: HttpRequest, username: str, password: str
|
||||
):
|
||||
# Authenticate user based on credentials
|
||||
user = User.objects.filter(username=username).first()
|
||||
user = User.objects.filter(username=username, is_active=True).first()
|
||||
if not user:
|
||||
raise TokenError("invalid_grant")
|
||||
token: Token = Token.filter_not_expired(
|
||||
@@ -378,9 +378,11 @@ class TokenParams:
|
||||
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||
LOGGER.warning("failed to parse JWT for kid lookup", exc=exc)
|
||||
raise TokenError("invalid_grant") from None
|
||||
expected_kid = decode_unvalidated["header"]["kid"]
|
||||
fallback_alg = decode_unvalidated["header"]["alg"]
|
||||
expected_kid = decode_unvalidated["header"].get("kid")
|
||||
fallback_alg = decode_unvalidated["header"].get("alg")
|
||||
token = source = None
|
||||
if not expected_kid or not fallback_alg:
|
||||
return None, None
|
||||
for source in self.provider.jwt_federation_sources.filter(
|
||||
oidc_jwks__keys__contains=[{"kid": expected_kid}]
|
||||
):
|
||||
|
||||
@@ -9,10 +9,9 @@ from defusedxml.lxml import fromstring
|
||||
from lxml import etree # nosec
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.crypto.models import CertificateKeyPair, format_cert
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
|
||||
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
|
||||
from authentik.sources.saml.models import SAMLNameIDPolicy
|
||||
from authentik.sources.saml.processors.constants import (
|
||||
NS_MAP,
|
||||
@@ -24,18 +23,6 @@ from authentik.sources.saml.processors.constants import (
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def format_pem_certificate(unformatted_cert: str) -> str:
|
||||
"""Format single, inline certificate into PEM Format"""
|
||||
# Ensure that all linebreaks are gone
|
||||
unformatted_cert = unformatted_cert.replace("\n", "")
|
||||
chunks, chunk_size = len(unformatted_cert), 64
|
||||
lines = [PEM_HEADER]
|
||||
for i in range(0, chunks, chunk_size):
|
||||
lines.append(unformatted_cert[i : i + chunk_size])
|
||||
lines.append(PEM_FOOTER)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ServiceProviderMetadata:
|
||||
"""SP Metadata Dataclass"""
|
||||
@@ -87,7 +74,7 @@ class ServiceProviderMetadataParser:
|
||||
)
|
||||
if len(signing_certs) < 1:
|
||||
return None
|
||||
raw_cert = format_pem_certificate(signing_certs[0])
|
||||
raw_cert = format_cert(signing_certs[0])
|
||||
# sanity check, make sure the certificate is valid.
|
||||
load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
|
||||
return CertificateKeyPair(
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
|
||||
PEM_FOOTER = "-----END CERTIFICATE-----"
|
||||
from ssl import PEM_FOOTER, PEM_HEADER
|
||||
|
||||
|
||||
def decode_base64_and_inflate(encoded: str, encoding="utf-8") -> str:
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Group client"""
|
||||
|
||||
from itertools import batched
|
||||
from typing import Any
|
||||
|
||||
from django.db import transaction
|
||||
from orjson import dumps
|
||||
from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.lib.merge import MERGE_LIST_UNIQUE
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
@@ -113,10 +116,23 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
self._patch_add_users(connection, users)
|
||||
return connection
|
||||
|
||||
def diff(self, local_created: dict[str, Any], connection: SCIMProviderUser):
|
||||
"""Check if a group is different than what we last wrote to the remote system.
|
||||
Returns true if there is a difference in data."""
|
||||
local_known = connection.attributes
|
||||
local_updated = {}
|
||||
MERGE_LIST_UNIQUE.merge(local_updated, local_known)
|
||||
MERGE_LIST_UNIQUE.merge(local_updated, local_created)
|
||||
return dumps(local_updated) != dumps(local_known)
|
||||
|
||||
def update(self, group: Group, connection: SCIMProviderGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group, connection)
|
||||
scim_group.id = connection.scim_id
|
||||
payload = scim_group.model_dump(mode="json", exclude_unset=True)
|
||||
if not self.diff(payload, connection):
|
||||
self.logger.debug("Skipping group write as data has not changed")
|
||||
return self.patch_compare_users(group)
|
||||
try:
|
||||
if self._config.patch.supported:
|
||||
return self._update_patch(group, scim_group, connection)
|
||||
|
||||
@@ -83,7 +83,7 @@ class EnterpriseUser(BaseModel):
|
||||
class User(BaseUser):
|
||||
"""Modified User schema with added externalId field"""
|
||||
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
model_config = ConfigDict(serialize_by_alias=True, extra="allow")
|
||||
|
||||
id: str | int | None = None
|
||||
schemas: list[str] = [SCIM_USER_SCHEMA]
|
||||
@@ -106,6 +106,8 @@ class User(BaseUser):
|
||||
class Group(BaseGroup):
|
||||
"""Modified Group schema with added externalId field"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
id: str | int | None = None
|
||||
schemas: list[str] = [SCIM_GROUP_SCHEMA]
|
||||
externalId: str | None = None
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""User client"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db import transaction
|
||||
from django.utils.http import urlencode
|
||||
from orjson import dumps
|
||||
from pydantic import ValidationError
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.lib.merge import MERGE_LIST_UNIQUE
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import ObjectExistsSyncException, StopSync
|
||||
from authentik.policies.utils import delete_none_values
|
||||
@@ -92,17 +96,30 @@ class SCIMUserClient(SCIMClient[User, SCIMProviderUser, SCIMUserSchema]):
|
||||
provider=self.provider, user=user, scim_id=scim_id, attributes=response
|
||||
)
|
||||
|
||||
def diff(self, local_created: dict[str, Any], connection: SCIMProviderUser):
|
||||
"""Check if a user is different than what we last wrote to the remote system.
|
||||
Returns true if there is a difference in data."""
|
||||
local_known = connection.attributes
|
||||
local_updated = {}
|
||||
MERGE_LIST_UNIQUE.merge(local_updated, local_known)
|
||||
MERGE_LIST_UNIQUE.merge(local_updated, local_created)
|
||||
return dumps(local_updated) != dumps(local_known)
|
||||
|
||||
def update(self, user: User, connection: SCIMProviderUser):
|
||||
"""Update existing user"""
|
||||
scim_user = self.to_schema(user, connection)
|
||||
scim_user.id = connection.scim_id
|
||||
payload = scim_user.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
)
|
||||
if not self.diff(payload, connection):
|
||||
self.logger.debug("Skipping user write as data has not changed")
|
||||
return
|
||||
response = self._request(
|
||||
"PUT",
|
||||
f"/Users/{connection.scim_id}",
|
||||
json=scim_user.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
),
|
||||
json=payload,
|
||||
)
|
||||
connection.attributes = response
|
||||
connection.save()
|
||||
|
||||
@@ -9,7 +9,7 @@ from requests_mock import Mocker
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderGroup
|
||||
|
||||
|
||||
class SCIMGroupTests(TestCase):
|
||||
@@ -106,6 +106,7 @@ class SCIMGroupTests(TestCase):
|
||||
"displayName": group.name,
|
||||
},
|
||||
)
|
||||
group.name = generate_id()
|
||||
group.save()
|
||||
self.assertEqual(mock.call_count, 4)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
@@ -148,3 +149,56 @@ class SCIMGroupTests(TestCase):
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[3].method, "DELETE")
|
||||
self.assertEqual(mock.request_history[3].url, f"https://localhost/Groups/{scim_id}")
|
||||
|
||||
@Mocker()
|
||||
def test_group_create_update_noop(self, mock: Mocker):
|
||||
"""Test group creation and update"""
|
||||
scim_id = generate_id()
|
||||
mock.get(
|
||||
"https://localhost/ServiceProviderConfig",
|
||||
json={},
|
||||
)
|
||||
mock.post(
|
||||
"https://localhost/Groups",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
mock.put(
|
||||
"https://localhost/Groups",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
uid = generate_id()
|
||||
group = Group.objects.create(
|
||||
name=uid,
|
||||
)
|
||||
self.assertEqual(mock.call_count, 2)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
body = loads(mock.request_history[1].body)
|
||||
with open("schemas/scim-group.schema.json", encoding="utf-8") as schema:
|
||||
validate(body, loads(schema.read()))
|
||||
self.assertEqual(
|
||||
body,
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
"externalId": str(group.pk),
|
||||
"displayName": group.name,
|
||||
},
|
||||
)
|
||||
conn = SCIMProviderGroup.objects.filter(group=group).first()
|
||||
conn.attributes = {
|
||||
"id": scim_id,
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
"externalId": str(group.pk),
|
||||
"displayName": group.name,
|
||||
}
|
||||
conn.save()
|
||||
group.save()
|
||||
self.assertEqual(mock.call_count, 4)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
|
||||
@@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.sync.outgoing.base import SAFE_METHODS
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMProviderUser
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tenants.models import Tenant
|
||||
@@ -95,7 +95,12 @@ class SCIMUserTests(TestCase):
|
||||
"""Test user creation with custom schema"""
|
||||
schema = SCIMMapping.objects.create(
|
||||
name="custom_schema",
|
||||
expression="""return {"schemas": ["foo"]}""",
|
||||
expression="""return {
|
||||
"schemas": ["urn:ietf:params:scim:schemas:extension:slack:profile:2.0:User"],
|
||||
"urn:ietf:params:scim:schemas:extension:slack:profile:2.0:User": {
|
||||
"startDate": "2024-04-10T00:00:00+0000",
|
||||
},
|
||||
}""",
|
||||
)
|
||||
self.provider.property_mappings.add(schema)
|
||||
scim_id = generate_id()
|
||||
@@ -121,7 +126,10 @@ class SCIMUserTests(TestCase):
|
||||
self.assertJSONEqual(
|
||||
mock.request_history[1].body,
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User", "foo"],
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
"urn:ietf:params:scim:schemas:extension:slack:profile:2.0:User",
|
||||
],
|
||||
"active": True,
|
||||
"emails": [
|
||||
{
|
||||
@@ -138,6 +146,9 @@ class SCIMUserTests(TestCase):
|
||||
},
|
||||
"displayName": f"{uid} {uid}",
|
||||
"userName": uid,
|
||||
"urn:ietf:params:scim:schemas:extension:slack:profile:2.0:User": {
|
||||
"startDate": "2024-04-10T00:00:00+0000",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -262,6 +273,8 @@ class SCIMUserTests(TestCase):
|
||||
"userName": uid,
|
||||
},
|
||||
)
|
||||
# Update user
|
||||
user.name = "foo bar"
|
||||
user.save()
|
||||
self.assertEqual(mock.call_count, 4)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
@@ -444,3 +457,85 @@ class SCIMUserTests(TestCase):
|
||||
self.assertIsNotNone(log.attributes["url"])
|
||||
self.assertIsNotNone(log.attributes["body"])
|
||||
self.assertIsNotNone(log.attributes["method"])
|
||||
|
||||
@Mocker()
|
||||
def test_user_create_update_noop(self, mock: Mocker):
|
||||
"""Test user creation and update"""
|
||||
scim_id = generate_id()
|
||||
mock: Mocker
|
||||
mock.get(
|
||||
"https://localhost/ServiceProviderConfig",
|
||||
json={},
|
||||
)
|
||||
mock.post(
|
||||
"https://localhost/Users",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
mock.put(
|
||||
"https://localhost/Users",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
uid = generate_id()
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
self.assertEqual(mock.call_count, 2)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
body = loads(mock.request_history[1].body)
|
||||
self.assertEqual(
|
||||
body,
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"active": True,
|
||||
"emails": [
|
||||
{
|
||||
"primary": True,
|
||||
"type": "other",
|
||||
"value": f"{uid}@goauthentik.io",
|
||||
}
|
||||
],
|
||||
"displayName": f"{uid} {uid}",
|
||||
"externalId": user.uid,
|
||||
"name": {
|
||||
"familyName": uid,
|
||||
"formatted": f"{uid} {uid}",
|
||||
"givenName": uid,
|
||||
},
|
||||
"userName": uid,
|
||||
},
|
||||
)
|
||||
conn = SCIMProviderUser.objects.filter(user=user).first()
|
||||
conn.attributes = {
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"active": True,
|
||||
"emails": [
|
||||
{
|
||||
"primary": True,
|
||||
"type": "other",
|
||||
"value": f"{uid}@goauthentik.io",
|
||||
}
|
||||
],
|
||||
"displayName": f"{uid} {uid}",
|
||||
"externalId": user.uid,
|
||||
"name": {
|
||||
"familyName": uid,
|
||||
"formatted": f"{uid} {uid}",
|
||||
"givenName": uid,
|
||||
},
|
||||
"userName": uid,
|
||||
"id": scim_id,
|
||||
}
|
||||
conn.save()
|
||||
user.save()
|
||||
self.assertEqual(mock.call_count, 3)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertEqual(mock.request_history[2].method, "GET")
|
||||
# No PUT request
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""websocket Message consumer"""
|
||||
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from django.core.cache import cache
|
||||
|
||||
from authentik.root.messages.storage import CACHE_PREFIX
|
||||
|
||||
|
||||
class MessageConsumer(JsonWebsocketConsumer):
|
||||
"""Consumer which sends django.contrib.messages Messages over WS.
|
||||
channel_name is saved into cache with user_id, and when a add_message is called"""
|
||||
|
||||
session_key: str
|
||||
|
||||
def connect(self):
|
||||
self.accept()
|
||||
self.session_key = self.scope["session"].session_key
|
||||
if not self.session_key:
|
||||
return
|
||||
cache.set(f"{CACHE_PREFIX}{self.session_key}_messages_{self.channel_name}", True, None)
|
||||
|
||||
def disconnect(self, code):
|
||||
cache.delete(f"{CACHE_PREFIX}{self.session_key}_messages_{self.channel_name}")
|
||||
|
||||
def event_update(self, event: dict):
|
||||
"""Event handler which is called by Messages Storage backend"""
|
||||
self.send_json(event)
|
||||
@@ -6,6 +6,7 @@ from hashlib import sha512
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
from django.http import response as http_response
|
||||
from sentry_sdk import set_tag
|
||||
from xmlsec import enable_debug_trace
|
||||
|
||||
@@ -248,7 +249,7 @@ SESSION_COOKIE_AGE = timedelta_from_string(
|
||||
).total_seconds()
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
||||
|
||||
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
|
||||
MESSAGE_STORAGE = "authentik.root.ws.storage.ChannelsStorage"
|
||||
|
||||
MIDDLEWARE_FIRST = [
|
||||
"django_prometheus.middleware.PrometheusBeforeMiddleware",
|
||||
@@ -379,9 +380,6 @@ DRAMATIQ = {
|
||||
"broker_class": "authentik.tasks.broker.Broker",
|
||||
"channel_prefix": "authentik",
|
||||
"task_model": "authentik.tasks.models.Task",
|
||||
"lock_purge_interval": timedelta_from_string(
|
||||
CONFIG.get("worker.lock_purge_interval")
|
||||
).total_seconds(),
|
||||
"task_purge_interval": timedelta_from_string(
|
||||
CONFIG.get("worker.task_purge_interval")
|
||||
).total_seconds(),
|
||||
@@ -429,6 +427,7 @@ DRAMATIQ = {
|
||||
},
|
||||
),
|
||||
("dramatiq.results.middleware.Results", {"store_results": True}),
|
||||
("authentik.tasks.middleware.StartupSignalsMiddleware", {}),
|
||||
("authentik.tasks.middleware.CurrentTask", {}),
|
||||
("authentik.tasks.middleware.TenantMiddleware", {}),
|
||||
("authentik.tasks.middleware.ModelDataMiddleware", {}),
|
||||
@@ -471,6 +470,12 @@ STORAGES = {
|
||||
},
|
||||
}
|
||||
|
||||
# Django 5.2.8 and CVE-2025-64458 added a strong enforcement of 2048 characters
|
||||
# as the maximum for a URL to redirect to, mostly for running on windows.
|
||||
# However our URLs can easily exceed that with OAuth/SAML Query parameters or hash values
|
||||
# 8192 should cover most cases..
|
||||
http_response.MAX_URL_LENGTH = http_response.MAX_URL_LENGTH * 4
|
||||
|
||||
|
||||
# Media files
|
||||
if CONFIG.get("storage.media.backend", "file") == "s3":
|
||||
|
||||
57
authentik/root/ws/consumer.py
Normal file
57
authentik/root/ws/consumer.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""websocket Message consumer"""
|
||||
|
||||
from hashlib import sha256
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from django.core.cache import cache
|
||||
from django.db import connection
|
||||
|
||||
from authentik.root.ws.storage import CACHE_PREFIX
|
||||
|
||||
|
||||
def build_session_group(session_key: str):
|
||||
return sha256(
|
||||
f"{connection.schema_name}/group_client_session_{str(session_key)}".encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def build_device_group(session_key: str):
|
||||
return sha256(
|
||||
f"{connection.schema_name}/group_client_device_{str(session_key)}".encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
class MessageConsumer(JsonWebsocketConsumer):
|
||||
"""Consumer which sends django.contrib.messages Messages over WS.
|
||||
channel_name is saved into cache with user_id, and when a add_message is called"""
|
||||
|
||||
session_key: str
|
||||
device_cookie: str | None = None
|
||||
|
||||
def connect(self):
|
||||
self.accept()
|
||||
self.session_key = self.scope["session"].session_key
|
||||
if self.session_key:
|
||||
cache.set(f"{CACHE_PREFIX}{self.session_key}_messages_{self.channel_name}", True, None)
|
||||
if device_cookie := self.scope["cookies"].get("authentik_device", None):
|
||||
self.device_cookie = device_cookie
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
build_device_group(self.device_cookie), self.channel_name
|
||||
)
|
||||
|
||||
def disconnect(self, code):
|
||||
if self.session_key:
|
||||
cache.delete(f"{CACHE_PREFIX}{self.session_key}_messages_{self.channel_name}")
|
||||
if self.device_cookie:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
build_device_group(self.device_cookie), self.channel_name
|
||||
)
|
||||
|
||||
def event_message(self, event: dict):
|
||||
"""Event handler which is called by Messages Storage backend"""
|
||||
self.send_json(event)
|
||||
|
||||
def event_session_authenticated(self, event: dict):
|
||||
"""Event handler post user authentication"""
|
||||
self.send_json({"message_type": "session.authenticated"})
|
||||
@@ -31,7 +31,7 @@ class ChannelsStorage(SessionStorage):
|
||||
async_to_sync(self.channel.send)(
|
||||
uid,
|
||||
{
|
||||
"type": "event.update",
|
||||
"type": "event.message",
|
||||
"message_type": "message",
|
||||
"level": message.level_tag,
|
||||
"tags": message.tags,
|
||||
@@ -298,6 +298,16 @@ class LDAPSource(ScheduledModel, Source):
|
||||
side_effect=pglock.Return,
|
||||
)
|
||||
|
||||
def get_ldap_server_info(self, srv: Server) -> dict[str, str]:
|
||||
info = {
|
||||
"vendor": _("N/A"),
|
||||
"version": _("N/A"),
|
||||
}
|
||||
if srv.info:
|
||||
info["vendor"] = str(flatten(srv.info.vendor_name))
|
||||
info["version"] = str(flatten(srv.info.vendor_version))
|
||||
return info
|
||||
|
||||
def check_connection(self) -> dict[str, dict[str, str]]:
|
||||
"""Check LDAP Connection"""
|
||||
servers = self.server()
|
||||
@@ -308,9 +318,8 @@ class LDAPSource(ScheduledModel, Source):
|
||||
try:
|
||||
conn = self.connection(server=server)
|
||||
server_info[server.host] = {
|
||||
"vendor": str(flatten(conn.server.info.vendor_name)),
|
||||
"version": str(flatten(conn.server.info.vendor_version)),
|
||||
"status": "ok",
|
||||
**self.get_ldap_server_info(conn.server),
|
||||
}
|
||||
except LDAPException as exc:
|
||||
server_info[server.host] = {
|
||||
@@ -320,9 +329,8 @@ class LDAPSource(ScheduledModel, Source):
|
||||
try:
|
||||
conn = self.connection()
|
||||
server_info["__all__"] = {
|
||||
"vendor": str(flatten(conn.server.info.vendor_name)),
|
||||
"version": str(flatten(conn.server.info.vendor_version)),
|
||||
"status": "ok",
|
||||
**self.get_ldap_server_info(conn.server),
|
||||
}
|
||||
except LDAPException as exc:
|
||||
server_info["__all__"] = {
|
||||
|
||||
@@ -143,7 +143,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||
if self.source.source_type.urls_customizable and self.source.pkce:
|
||||
pkce_mode = self.source.pkce
|
||||
if pkce_mode != PKCEMethod.NONE:
|
||||
verifier = generate_id()
|
||||
verifier = generate_id(length=128)
|
||||
self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
|
||||
# https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
|
||||
if pkce_mode == PKCEMethod.PLAIN:
|
||||
|
||||
@@ -205,6 +205,7 @@ class TestOAuthSource(APITestCase):
|
||||
session = self.client.session
|
||||
state = session[f"oauth-client-{self.source.name}-request-state"]
|
||||
verifier = session[SESSION_KEY_OAUTH_PKCE]
|
||||
self.assertEqual(len(verifier), 128)
|
||||
challenge = pkce_s256_challenge(verifier)
|
||||
|
||||
self.assertEqual(qs["redirect_uri"], ["http://testserver/source/oauth/callback/test/"])
|
||||
|
||||
@@ -11,10 +11,10 @@ from authentik.stages.invitation.models import Invitation, InvitationStage
|
||||
from authentik.stages.invitation.signals import invitation_used
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
INVITATION_TOKEN_KEY_CONTEXT = "token" # nosec
|
||||
INVITATION_TOKEN_KEY = "itoken" # nosec
|
||||
INVITATION_IN_EFFECT = "invitation_in_effect"
|
||||
INVITATION = "invitation"
|
||||
QS_INVITATION_TOKEN_KEY = "itoken" # nosec
|
||||
PLAN_CONTEXT_INVITATION_TOKEN = "token" # nosec
|
||||
PLAN_CONTEXT_INVITATION_IN_EFFECT = "invitation_in_effect"
|
||||
PLAN_CONTEXT_INVITATION = "invitation"
|
||||
|
||||
|
||||
class InvitationStageView(StageView):
|
||||
@@ -23,13 +23,13 @@ class InvitationStageView(StageView):
|
||||
def get_token(self) -> str | None:
|
||||
"""Get token from saved get-arguments or prompt_data"""
|
||||
# Check for ?token= and ?itoken=
|
||||
if INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}):
|
||||
return self.request.session[SESSION_KEY_GET][INVITATION_TOKEN_KEY]
|
||||
if INVITATION_TOKEN_KEY_CONTEXT in self.request.session.get(SESSION_KEY_GET, {}):
|
||||
return self.request.session[SESSION_KEY_GET][INVITATION_TOKEN_KEY_CONTEXT]
|
||||
if QS_INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}):
|
||||
return self.request.session[SESSION_KEY_GET][QS_INVITATION_TOKEN_KEY]
|
||||
if PLAN_CONTEXT_INVITATION_TOKEN in self.request.session.get(SESSION_KEY_GET, {}):
|
||||
return self.request.session[SESSION_KEY_GET][PLAN_CONTEXT_INVITATION_TOKEN]
|
||||
# Check for {'token': ''} in the context
|
||||
if INVITATION_TOKEN_KEY_CONTEXT in self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}):
|
||||
return self.executor.plan.context[PLAN_CONTEXT_PROMPT][INVITATION_TOKEN_KEY_CONTEXT]
|
||||
if PLAN_CONTEXT_INVITATION_TOKEN in self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}):
|
||||
return self.executor.plan.context[PLAN_CONTEXT_PROMPT][PLAN_CONTEXT_INVITATION_TOKEN]
|
||||
return None
|
||||
|
||||
def get_invite(self) -> Invitation | None:
|
||||
@@ -38,7 +38,7 @@ class InvitationStageView(StageView):
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
invite: Invitation = Invitation.objects.filter(pk=token).first()
|
||||
invite: Invitation | None = Invitation.filter_not_expired(pk=token).first()
|
||||
except ValidationError:
|
||||
self.logger.debug("invalid invitation", token=token)
|
||||
return None
|
||||
@@ -60,8 +60,8 @@ class InvitationStageView(StageView):
|
||||
return self.executor.stage_ok()
|
||||
return self.executor.stage_invalid(_("Invalid invite/invite not found"))
|
||||
|
||||
self.executor.plan.context[INVITATION_IN_EFFECT] = True
|
||||
self.executor.plan.context[INVITATION] = invite
|
||||
self.executor.plan.context[PLAN_CONTEXT_INVITATION_IN_EFFECT] = True
|
||||
self.executor.plan.context[PLAN_CONTEXT_INVITATION] = invite
|
||||
|
||||
context = {}
|
||||
always_merger.merge(context, self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}))
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""invitation tests"""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from django.utils.http import urlencode
|
||||
from django.utils.timezone import now
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@@ -16,9 +18,9 @@ from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.stages.invitation.models import Invitation, InvitationStage
|
||||
from authentik.stages.invitation.stage import (
|
||||
INVITATION_TOKEN_KEY,
|
||||
INVITATION_TOKEN_KEY_CONTEXT,
|
||||
PLAN_CONTEXT_INVITATION_TOKEN,
|
||||
PLAN_CONTEXT_PROMPT,
|
||||
QS_INVITATION_TOKEN_KEY,
|
||||
)
|
||||
from authentik.stages.password import BACKEND_INBUILT
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
|
||||
@@ -77,6 +79,31 @@ class TestInvitationStage(FlowTestCase):
|
||||
self.stage.continue_flow_without_invitation = False
|
||||
self.stage.save()
|
||||
|
||||
def test_with_invitation_expired(self):
|
||||
"""Test with invitation, expired"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
session = self.client.session
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
|
||||
data = {"foo": "bar"}
|
||||
invite = Invitation.objects.create(
|
||||
created_by=get_anonymous_user(),
|
||||
fixed_data=data,
|
||||
expires=now() - timedelta(hours=1),
|
||||
)
|
||||
|
||||
base_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
args = urlencode({QS_INVITATION_TOKEN_KEY: invite.pk.hex})
|
||||
response = self.client.get(base_url + f"?query={args}")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertStageResponse(
|
||||
response,
|
||||
flow=self.flow,
|
||||
component="ak-stage-access-denied",
|
||||
)
|
||||
|
||||
def test_with_invitation_get(self):
|
||||
"""Test with invitation, check data in session"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
@@ -89,7 +116,7 @@ class TestInvitationStage(FlowTestCase):
|
||||
|
||||
with patch("authentik.flows.views.executor.FlowExecutorView.cancel", MagicMock()):
|
||||
base_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
args = urlencode({INVITATION_TOKEN_KEY: invite.pk.hex})
|
||||
args = urlencode({QS_INVITATION_TOKEN_KEY: invite.pk.hex})
|
||||
response = self.client.get(base_url + f"?query={args}")
|
||||
|
||||
session = self.client.session
|
||||
@@ -114,7 +141,7 @@ class TestInvitationStage(FlowTestCase):
|
||||
|
||||
with patch("authentik.flows.views.executor.FlowExecutorView.cancel", MagicMock()):
|
||||
base_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
args = urlencode({INVITATION_TOKEN_KEY: invite.pk.hex})
|
||||
args = urlencode({QS_INVITATION_TOKEN_KEY: invite.pk.hex})
|
||||
response = self.client.get(base_url + f"?query={args}")
|
||||
|
||||
session = self.client.session
|
||||
@@ -134,7 +161,7 @@ class TestInvitationStage(FlowTestCase):
|
||||
)
|
||||
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PROMPT] = {INVITATION_TOKEN_KEY_CONTEXT: invite.pk.hex}
|
||||
plan.context[PLAN_CONTEXT_PROMPT] = {PLAN_CONTEXT_INVITATION_TOKEN: invite.pk.hex}
|
||||
session = self.client.session
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
|
||||
@@ -261,7 +261,9 @@ class Prompt(SerializerModel):
|
||||
|
||||
return value
|
||||
|
||||
def field(self, default: Any | None, choices: list[Any] | None = None) -> CharField:
|
||||
def field( # noqa PLR0915
|
||||
self, default: Any | None, choices: list[Any] | None = None
|
||||
) -> CharField:
|
||||
"""Get field type for Challenge and response. Choices are only valid for CHOICE_FIELDS."""
|
||||
field_class = CharField
|
||||
kwargs = {
|
||||
@@ -275,6 +277,7 @@ class Prompt(SerializerModel):
|
||||
field_class = ReadOnlyField
|
||||
# required can't be set for ReadOnlyField
|
||||
kwargs["required"] = False
|
||||
kwargs["allow_blank"] = True
|
||||
case FieldTypes.EMAIL:
|
||||
field_class = EmailField
|
||||
kwargs["allow_blank"] = not self.required
|
||||
@@ -306,7 +309,14 @@ class Prompt(SerializerModel):
|
||||
|
||||
if self.type in CHOICE_FIELDS:
|
||||
field_class = ChoiceField
|
||||
kwargs["choices"] = choices or []
|
||||
kwargs["choices"] = []
|
||||
if choices:
|
||||
for choice in choices:
|
||||
label, value = choice, choice
|
||||
if isinstance(choice, dict):
|
||||
label = choice.get("label", "")
|
||||
value = choice.get("value", "")
|
||||
kwargs["choices"].append((value, label))
|
||||
|
||||
if default:
|
||||
kwargs["default"] = default
|
||||
|
||||
@@ -23,6 +23,7 @@ from authentik import authentik_full_version
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sentry import should_ignore_exception
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.root.signals import post_startup, pre_startup, startup
|
||||
from authentik.tasks.models import Task, TaskLog, TaskStatus, WorkerStatus
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
@@ -32,6 +33,14 @@ HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
|
||||
DB_ERRORS = (OperationalError, Error)
|
||||
|
||||
|
||||
class StartupSignalsMiddleware(Middleware):
|
||||
def after_process_boot(self, broker: Broker):
|
||||
_startup_sender = type("WorkerStartup", (object,), {})
|
||||
pre_startup.send(sender=_startup_sender)
|
||||
startup.send(sender=_startup_sender)
|
||||
post_startup.send(sender=_startup_sender)
|
||||
|
||||
|
||||
class CurrentTask(BaseCurrentTask):
|
||||
@classmethod
|
||||
def get_task(cls) -> Task:
|
||||
|
||||
@@ -9,6 +9,7 @@ from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.models import TaskBase, TaskState
|
||||
|
||||
from authentik.events.logs import LogEvent
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.tenants.models import Tenant
|
||||
@@ -174,7 +175,7 @@ class TaskLog(models.Model):
|
||||
log_level=log_event.log_level,
|
||||
logger=log_event.logger,
|
||||
timestamp=log_event.timestamp,
|
||||
attributes=log_event.attributes,
|
||||
attributes=sanitize_item(log_event.attributes),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -193,7 +194,7 @@ class TaskLog(models.Model):
|
||||
log_level=log_event.log_level,
|
||||
logger=log_event.logger,
|
||||
timestamp=log_event.timestamp,
|
||||
attributes=log_event.attributes,
|
||||
attributes=sanitize_item(log_event.attributes),
|
||||
)
|
||||
for log_event in log_events
|
||||
]
|
||||
|
||||
@@ -2,9 +2,10 @@ import pickle # nosec
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from dramatiq.actor import Actor
|
||||
from psqlextra.query import ConflictAction
|
||||
from psqlextra.types import ConflictAction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
@@ -15,7 +16,7 @@ class ScheduleSpec:
|
||||
actor: Actor
|
||||
crontab: str
|
||||
paused: bool = False
|
||||
identifier: str | None = None
|
||||
identifier: str | UUID | None = None
|
||||
uid: str | None = None
|
||||
|
||||
args: Iterable[Any] = field(default_factory=tuple)
|
||||
@@ -41,6 +42,8 @@ class ScheduleSpec:
|
||||
return pickle.dumps(options)
|
||||
|
||||
def update_or_create(self) -> "Schedule":
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
update_values = {
|
||||
@@ -50,10 +53,12 @@ class ScheduleSpec:
|
||||
"kwargs": self.get_kwargs(),
|
||||
"options": self.get_options(),
|
||||
}
|
||||
if self.rel_obj is not None:
|
||||
update_values["rel_obj_content_type"] = ContentType.objects.get_for_model(self.rel_obj)
|
||||
update_values["rel_obj_id"] = str(self.rel_obj.pk)
|
||||
create_values = {
|
||||
**update_values,
|
||||
"crontab": self.crontab,
|
||||
"rel_obj": self.rel_obj,
|
||||
}
|
||||
|
||||
schedule = Schedule.objects.on_conflict(
|
||||
@@ -62,7 +67,7 @@ class ScheduleSpec:
|
||||
update_values=update_values,
|
||||
).insert_and_get(
|
||||
actor_name=self.actor.actor_name,
|
||||
identifier=self.identifier,
|
||||
identifier=str(self.identifier),
|
||||
**create_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ def post_save_scheduled_model(sender, instance, **_):
|
||||
return
|
||||
for spec in instance.schedule_specs:
|
||||
spec.rel_obj = instance
|
||||
spec.identifier = instance.pk
|
||||
schedule = spec.update_or_create()
|
||||
if spec.send_on_save:
|
||||
schedule.send()
|
||||
|
||||
@@ -5,10 +5,3 @@ setup()
|
||||
import django # noqa: E402
|
||||
|
||||
django.setup()
|
||||
|
||||
from authentik.root.signals import post_startup, pre_startup, startup # noqa: E402
|
||||
|
||||
_startup_sender = type("WorkerStartup", (object,), {})
|
||||
pre_startup.send(sender=_startup_sender)
|
||||
startup.send(sender=_startup_sender)
|
||||
post_startup.send(sender=_startup_sender)
|
||||
|
||||
@@ -5,7 +5,8 @@ from json import loads
|
||||
from django.urls import reverse
|
||||
from django_tenants.utils import get_public_schema_name
|
||||
|
||||
from authentik.core.models import Token, TokenIntents, User
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.tenants.models import Tenant
|
||||
@@ -21,7 +22,7 @@ class TestRecovery(TenantAPITestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tenant = Tenant.objects.get(schema_name=get_public_schema_name())
|
||||
self.user: User = User.objects.create_user(username="recovery-test-user")
|
||||
self.user = create_test_user()
|
||||
|
||||
@CONFIG.patch("outposts.disable_embedded_outpost", True)
|
||||
@CONFIG.patch("tenants.enabled", True)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://goauthentik.io/blueprints/schema.json",
|
||||
"type": "object",
|
||||
"title": "authentik 2025.10.0-rc1 Blueprint schema",
|
||||
"title": "authentik 2025.10.3 Blueprint schema",
|
||||
"required": [
|
||||
"version",
|
||||
"entries"
|
||||
|
||||
@@ -60,22 +60,6 @@ func checkServer() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func splitHostPort(address string) (host, port string) {
|
||||
lastColon := strings.LastIndex(address, ":")
|
||||
if lastColon == -1 {
|
||||
return address, ""
|
||||
}
|
||||
|
||||
host = address[:lastColon]
|
||||
port = address[lastColon+1:]
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return host, port
|
||||
}
|
||||
|
||||
func checkWorker() int {
|
||||
pidB, err := os.ReadFile(workerPidFile)
|
||||
if err != nil {
|
||||
@@ -98,41 +82,6 @@ func checkWorker() int {
|
||||
log.WithError(err).Warning("failed to signal worker process")
|
||||
return 1
|
||||
}
|
||||
h := &http.Client{
|
||||
Transport: web.NewUserAgentTransport("goauthentik.io/healthcheck", http.DefaultTransport),
|
||||
}
|
||||
|
||||
host, port := splitHostPort(config.Get().Listen.HTTP)
|
||||
|
||||
if host == "0.0.0.0" || host == "::" {
|
||||
url := fmt.Sprintf("http://%s:%s/-/health/ready/", "::1", port)
|
||||
_, err := h.Head(url)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("url", url).Warning("failed to send healthcheck request")
|
||||
url := fmt.Sprintf("http://%s:%s/-/health/ready/", "127.0.0.1", port)
|
||||
res, err := h.Head(url)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("url", url).Warning("failed to send healthcheck request")
|
||||
return 1
|
||||
}
|
||||
if res.StatusCode >= 400 {
|
||||
log.WithField("status", res.StatusCode).Warning("unhealthy status code")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
} else {
|
||||
url := fmt.Sprintf("http://%s:%s/-/health/ready/", host, port)
|
||||
res, err := h.Head(url)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to send healthcheck request")
|
||||
return 1
|
||||
}
|
||||
if res.StatusCode >= 400 {
|
||||
log.WithField("status", res.StatusCode).Warning("unhealthy status code")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("successfully checked health")
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ services:
|
||||
AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
|
||||
AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
|
||||
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?secret key required}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.10.0-rc1}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.10.3}
|
||||
ports:
|
||||
- ${COMPOSE_PORT_HTTP:-9000}:9000
|
||||
- ${COMPOSE_PORT_HTTPS:-9443}:9443
|
||||
@@ -52,7 +52,7 @@ services:
|
||||
AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
|
||||
AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
|
||||
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?secret key required}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.10.0-rc1}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.10.3}
|
||||
restart: unless-stopped
|
||||
user: root
|
||||
volumes:
|
||||
|
||||
@@ -1 +1 @@
|
||||
2025.10.0-rc1
|
||||
2025.10.3
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func OpensslVersion() string {
|
||||
cmd := exec.Command("openssl", "version")
|
||||
cmd := exec.Command("/usr/bin/openssl", "version")
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
err := cmd.Run()
|
||||
|
||||
@@ -93,7 +93,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
||||
}),
|
||||
)
|
||||
if len(outposts.Results) < 1 {
|
||||
log.Panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
|
||||
log.Panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost")
|
||||
}
|
||||
outpost := outposts.Results[0]
|
||||
|
||||
@@ -122,6 +122,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
||||
eventHandlers: []EventHandler{},
|
||||
refreshHandlers: make([]func(), 0),
|
||||
}
|
||||
ac.logger.WithField("embedded", ac.IsEmbedded()).Info("Outpost mode")
|
||||
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
|
||||
err = ac.initEvent(akURL, outpost.Pk)
|
||||
if err != nil {
|
||||
@@ -135,6 +136,13 @@ func (a *APIController) Log() *log.Entry {
|
||||
return a.logger
|
||||
}
|
||||
|
||||
func (a *APIController) IsEmbedded() bool {
|
||||
if m := a.Outpost.Managed.Get(); m != nil {
|
||||
return *m == "goauthentik.io/outposts/embedded"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Start Starts all handlers, non-blocking
|
||||
func (a *APIController) Start() error {
|
||||
err := a.Server.Refresh()
|
||||
|
||||
@@ -66,6 +66,7 @@ type Server interface {
|
||||
API() *ak.APIController
|
||||
Apps() []*Application
|
||||
CryptoStore() *ak.CryptoStore
|
||||
SessionBackend() string
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -94,10 +95,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old
|
||||
CallbackSignature: []string{"true"},
|
||||
}.Encode()
|
||||
|
||||
isEmbedded := false
|
||||
if m := server.API().Outpost.Managed.Get(); m != nil {
|
||||
isEmbedded = *m == "goauthentik.io/outposts/embedded"
|
||||
}
|
||||
isEmbedded := server.API().IsEmbedded()
|
||||
// Configure an OpenID Connect aware OAuth2 client.
|
||||
endpoint := GetOIDCEndpoint(
|
||||
p,
|
||||
@@ -153,6 +151,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old
|
||||
go a.authHeaderCache.Start()
|
||||
if oldApp != nil && oldApp.sessions != nil {
|
||||
a.sessions = oldApp.sessions
|
||||
muxLogger.Debug("reusing existing session store")
|
||||
} else {
|
||||
sess, err := a.getStore(p, externalHost)
|
||||
if err != nil {
|
||||
@@ -161,7 +160,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old
|
||||
a.sessions = sess
|
||||
}
|
||||
mux.Use(web.NewLoggingHandler(muxLogger, func(l *log.Entry, r *http.Request) *log.Entry {
|
||||
c := a.getClaimsFromSession(r)
|
||||
c := a.getClaimsFromSession(nil, r)
|
||||
if c == nil {
|
||||
return l
|
||||
}
|
||||
@@ -172,7 +171,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old
|
||||
}))
|
||||
mux.Use(func(inner http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
c := a.getClaimsFromSession(r)
|
||||
c := a.getClaimsFromSession(nil, r)
|
||||
user := ""
|
||||
if c != nil {
|
||||
user = c.PreferredUsername
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// checkAuth Get claims which are currently in session
|
||||
// Returns an error if the session can't be loaded or the claims can't be parsed/type-cast
|
||||
func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*types.Claims, error) {
|
||||
c := a.getClaimsFromSession(r)
|
||||
c := a.getClaimsFromSession(rw, r)
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
@@ -50,10 +50,17 @@ func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*types
|
||||
return nil, fmt.Errorf("failed to get claims from session")
|
||||
}
|
||||
|
||||
func (a *Application) getClaimsFromSession(r *http.Request) *types.Claims {
|
||||
func (a *Application) getClaimsFromSession(rw http.ResponseWriter, r *http.Request) *types.Claims {
|
||||
s, err := a.sessions.Get(r, a.SessionName())
|
||||
if err != nil {
|
||||
// err == user has no session/session is not valid, reject
|
||||
// err == user has no session/session is not valid
|
||||
// Delete the stale session cookie if it exists
|
||||
if rw != nil {
|
||||
s.Options.MaxAge = -1
|
||||
if saveErr := s.Save(r, rw); saveErr != nil {
|
||||
a.log.WithError(saveErr).Warning("failed to delete stale session cookie")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
claims, ok := s.Values[constants.SessionClaims]
|
||||
@@ -64,7 +71,7 @@ func (a *Application) getClaimsFromSession(r *http.Request) *types.Claims {
|
||||
|
||||
// Claims are always stored as types.Claims but may be deserialized differently:
|
||||
// - Filesystem store (gob): preserves struct type as types.Claims
|
||||
// - PostgreSQL store (JSON): deserializes as map[string]interface{}
|
||||
// - PostgreSQL store (JSON): deserializes as map[string]any
|
||||
|
||||
// Handle struct type (filesystem store)
|
||||
if c, ok := claims.(types.Claims); ok {
|
||||
@@ -72,7 +79,7 @@ func (a *Application) getClaimsFromSession(r *http.Request) *types.Claims {
|
||||
}
|
||||
|
||||
// Handle map type (PostgreSQL store)
|
||||
if claimsMap, ok := claims.(map[string]interface{}); ok {
|
||||
if claimsMap, ok := claims.(map[string]any); ok {
|
||||
var c types.Claims
|
||||
if err := mapstructure.Decode(claimsMap, &c); err != nil {
|
||||
return nil
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -27,7 +28,7 @@ func TestClaimsJSONSerialization(t *testing.T) {
|
||||
Entitlements: []string{"read", "write"},
|
||||
Sid: "session-id-456",
|
||||
Proxy: &types.ProxyClaims{
|
||||
UserAttributes: map[string]interface{}{
|
||||
UserAttributes: map[string]any{
|
||||
"custom_field": "custom_value",
|
||||
"department": "engineering",
|
||||
},
|
||||
@@ -70,35 +71,33 @@ func TestClaimsJSONSerialization(t *testing.T) {
|
||||
assert.Equal(t, "engineering", parsedClaims.Proxy.UserAttributes["department"])
|
||||
}
|
||||
|
||||
// TestClaimsMapSerialization tests that Claims stored as map[string]interface{} can be converted back
|
||||
// TestClaimsMapSerialization tests that Claims stored as map[string]any can be converted back
|
||||
func TestClaimsMapSerialization(t *testing.T) {
|
||||
// Simulate how claims are stored in session as map (like from PostgreSQL JSONB)
|
||||
claimsMap := map[string]interface{}{
|
||||
claimsMap := map[string]any{
|
||||
"sub": "user-id-123",
|
||||
"exp": float64(1234567890), // json numbers become float64
|
||||
"email": "test@example.com",
|
||||
"email_verified": true,
|
||||
"name": "Test User",
|
||||
"preferred_username": "testuser",
|
||||
"groups": []interface{}{"admin", "user"},
|
||||
"entitlements": []interface{}{"read", "write"},
|
||||
"groups": []any{"admin", "user"},
|
||||
"entitlements": []any{"read", "write"},
|
||||
"sid": "session-id-456",
|
||||
"ak_proxy": map[string]interface{}{
|
||||
"user_attributes": map[string]interface{}{
|
||||
"ak_proxy": map[string]any{
|
||||
"user_attributes": map[string]any{
|
||||
"custom_field": "custom_value",
|
||||
},
|
||||
"backend_override": "custom-backend",
|
||||
"host_header": "example.com",
|
||||
"is_superuser": true,
|
||||
},
|
||||
"raw_token": "not-a-real-token",
|
||||
}
|
||||
|
||||
// Convert map to Claims using JSON marshaling (like getClaimsFromSession does)
|
||||
jsonData, err := json.Marshal(claimsMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Convert map to Claims using mapstructure marshaling (like getClaimsFromSession does)
|
||||
var claims types.Claims
|
||||
err = json.Unmarshal(jsonData, &claims)
|
||||
err := mapstructure.Decode(claimsMap, &claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify fields
|
||||
@@ -111,6 +110,7 @@ func TestClaimsMapSerialization(t *testing.T) {
|
||||
assert.Equal(t, []string{"admin", "user"}, claims.Groups)
|
||||
assert.Equal(t, []string{"read", "write"}, claims.Entitlements)
|
||||
assert.Equal(t, "session-id-456", claims.Sid)
|
||||
assert.Equal(t, "not-a-real-token", claims.RawToken)
|
||||
|
||||
// Verify proxy claims
|
||||
require.NotNil(t, claims.Proxy)
|
||||
@@ -122,7 +122,7 @@ func TestClaimsMapSerialization(t *testing.T) {
|
||||
|
||||
// TestClaimsMinimalFields tests that Claims work with minimal required fields
|
||||
func TestClaimsMinimalFields(t *testing.T) {
|
||||
claimsMap := map[string]interface{}{
|
||||
claimsMap := map[string]any{
|
||||
"sub": "user-id-123",
|
||||
"exp": float64(1234567890),
|
||||
}
|
||||
@@ -144,11 +144,11 @@ func TestClaimsMinimalFields(t *testing.T) {
|
||||
|
||||
// TestClaimsWithEmptyArrays tests that empty arrays are handled correctly
|
||||
func TestClaimsWithEmptyArrays(t *testing.T) {
|
||||
claimsMap := map[string]interface{}{
|
||||
claimsMap := map[string]any{
|
||||
"sub": "user-id-123",
|
||||
"exp": float64(1234567890),
|
||||
"groups": []interface{}{},
|
||||
"entitlements": []interface{}{},
|
||||
"groups": []any{},
|
||||
"entitlements": []any{},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(claimsMap)
|
||||
@@ -167,7 +167,7 @@ func TestClaimsWithEmptyArrays(t *testing.T) {
|
||||
|
||||
// TestClaimsWithNullProxyClaims tests that null proxy claims don't cause issues
|
||||
func TestClaimsWithNullProxyClaims(t *testing.T) {
|
||||
claimsMap := map[string]interface{}{
|
||||
claimsMap := map[string]any{
|
||||
"sub": "user-id-123",
|
||||
"exp": float64(1234567890),
|
||||
"ak_proxy": nil,
|
||||
@@ -185,18 +185,18 @@ func TestClaimsWithNullProxyClaims(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestGetClaimsFromSession_Success tests successful retrieval of claims from session
|
||||
// uses a mock session that returns claims as map[string]interface{} to simulate
|
||||
// uses a mock session that returns claims as map[string]any to simulate
|
||||
// how PostgreSQL storage deserializes JSONB data
|
||||
func TestGetClaimsFromSession_Success(t *testing.T) {
|
||||
// Create a custom mock store that returns claims as map
|
||||
store := &mockMapSessionStore{
|
||||
claimsMap: map[string]interface{}{
|
||||
claimsMap: map[string]any{
|
||||
"sub": "user-id-123",
|
||||
"exp": float64(1234567890),
|
||||
"email": "test@example.com",
|
||||
"email_verified": true,
|
||||
"preferred_username": "testuser",
|
||||
"groups": []interface{}{"admin", "user"},
|
||||
"groups": []any{"admin", "user"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func TestGetClaimsFromSession_Success(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Test getClaimsFromSession
|
||||
claims := app.getClaimsFromSession(req)
|
||||
claims := app.getClaimsFromSession(nil, req)
|
||||
require.NotNil(t, claims)
|
||||
assert.Equal(t, "user-id-123", claims.Sub)
|
||||
assert.Equal(t, 1234567890, claims.Exp)
|
||||
@@ -217,9 +217,9 @@ func TestGetClaimsFromSession_Success(t *testing.T) {
|
||||
assert.Equal(t, []string{"admin", "user"}, claims.Groups)
|
||||
}
|
||||
|
||||
// mockMapSessionStore is a mock session store that returns claims as map[string]interface{}
|
||||
// mockMapSessionStore is a mock session store that returns claims as map[string]any
|
||||
type mockMapSessionStore struct {
|
||||
claimsMap map[string]interface{}
|
||||
claimsMap map[string]any
|
||||
}
|
||||
|
||||
func (m *mockMapSessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||||
@@ -250,7 +250,7 @@ func TestGetClaimsFromSession_NoSession(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
claims := app.getClaimsFromSession(req)
|
||||
claims := app.getClaimsFromSession(nil, req)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func TestGetClaimsFromSession_NoClaims(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
claims := app.getClaimsFromSession(req)
|
||||
claims := app.getClaimsFromSession(nil, req)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func TestGetClaimsFromSession_InvalidClaimsType(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
claims := app.getClaimsFromSession(req)
|
||||
claims := app.getClaimsFromSession(nil, req)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ func TestClaimsRoundTrip(t *testing.T) {
|
||||
Entitlements: []string{"ent1", "ent2"},
|
||||
Sid: "session-789",
|
||||
Proxy: &types.ProxyClaims{
|
||||
UserAttributes: map[string]interface{}{
|
||||
UserAttributes: map[string]any{
|
||||
"attr1": "value1",
|
||||
"attr2": float64(42),
|
||||
"attr3": true,
|
||||
@@ -329,8 +329,8 @@ func TestClaimsRoundTrip(t *testing.T) {
|
||||
jsonData, err := json.Marshal(originalClaims)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: Deserialize to map[string]interface{} (simulating PostgreSQL load)
|
||||
var claimsMap map[string]interface{}
|
||||
// Step 2: Deserialize to map[string]any (simulating PostgreSQL load)
|
||||
var claimsMap map[string]any
|
||||
err = json.Unmarshal(jsonData, &claimsMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -14,62 +14,83 @@ import (
|
||||
"goauthentik.io/internal/outpost/proxyv2/types"
|
||||
)
|
||||
|
||||
func (a *Application) addHeaders(headers http.Header, c *types.Claims) {
|
||||
nh := a.getHeaders(c)
|
||||
for key, val := range nh {
|
||||
headers.Set(key, val)
|
||||
}
|
||||
a.removeDuplicateUnderscoreHeader(headers)
|
||||
}
|
||||
|
||||
func (a *Application) removeDuplicateUnderscoreHeader(h http.Header) {
|
||||
for key := range h {
|
||||
ush := strings.ReplaceAll(key, "_", "-")
|
||||
if _, ok := h[ush]; !ok {
|
||||
h.Del(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) getHeaders(c *types.Claims) map[string]string {
|
||||
headers := map[string]string{}
|
||||
// https://docs.goauthentik.io/add-secure-apps/providers/proxy
|
||||
headers["X-authentik-username"] = c.PreferredUsername
|
||||
headers["X-authentik-groups"] = strings.Join(c.Groups, "|")
|
||||
headers["X-authentik-entitlements"] = strings.Join(c.Entitlements, "|")
|
||||
headers["X-authentik-email"] = c.Email
|
||||
headers["X-authentik-name"] = c.Name
|
||||
headers["X-authentik-uid"] = c.Sub
|
||||
headers["X-authentik-jwt"] = c.RawToken
|
||||
|
||||
// System headers
|
||||
headers["X-authentik-meta-jwks"] = a.endpoint.JwksUri
|
||||
headers["X-authentik-meta-outpost"] = a.outpostName
|
||||
headers["X-authentik-meta-provider"] = a.proxyConfig.Name
|
||||
headers["X-authentik-meta-app"] = a.proxyConfig.AssignedApplicationSlug
|
||||
headers["X-authentik-meta-version"] = constants.UserAgentOutpost()
|
||||
|
||||
if c.Proxy == nil {
|
||||
return headers
|
||||
}
|
||||
if authz := a.setAuthorizationHeader(c); authz != "" {
|
||||
headers["Authorization"] = authz
|
||||
}
|
||||
// Check if user has additional headers set that we should sent
|
||||
userAttributes := c.Proxy.UserAttributes
|
||||
if additionalHeaders, ok := userAttributes["additionalHeaders"]; ok {
|
||||
a.log.WithField("headers", additionalHeaders).Trace("setting additional headers")
|
||||
if additionalHeaders == nil {
|
||||
return headers
|
||||
}
|
||||
for key, value := range additionalHeaders.(map[string]interface{}) {
|
||||
headers[key] = toString(value)
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// Attempt to set basic auth based on user's attributes
|
||||
func (a *Application) setAuthorizationHeader(headers http.Header, c *types.Claims) {
|
||||
func (a *Application) setAuthorizationHeader(c *types.Claims) string {
|
||||
if !*a.proxyConfig.BasicAuthEnabled {
|
||||
return
|
||||
return ""
|
||||
}
|
||||
userAttributes := c.Proxy.UserAttributes
|
||||
var ok bool
|
||||
var username string
|
||||
var password string
|
||||
if password, ok = userAttributes[*a.proxyConfig.BasicAuthPasswordAttribute].(string); !ok {
|
||||
password = ""
|
||||
}
|
||||
// Check if we should use email or a custom attribute as username
|
||||
var username string
|
||||
if username, ok = userAttributes[*a.proxyConfig.BasicAuthUserAttribute].(string); !ok {
|
||||
username = c.Email
|
||||
}
|
||||
if username == "" && password == "" {
|
||||
return
|
||||
if password == "" {
|
||||
return ""
|
||||
}
|
||||
authVal := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
a.log.WithField("username", username).Trace("setting http basic auth")
|
||||
headers.Set("Authorization", fmt.Sprintf("Basic %s", authVal))
|
||||
}
|
||||
|
||||
func (a *Application) addHeaders(headers http.Header, c *types.Claims) {
|
||||
// https://docs.goauthentik.io/add-secure-apps/providers/proxy
|
||||
headers.Set("X-authentik-username", c.PreferredUsername)
|
||||
headers.Set("X-authentik-groups", strings.Join(c.Groups, "|"))
|
||||
headers.Set("X-authentik-entitlements", strings.Join(c.Entitlements, "|"))
|
||||
headers.Set("X-authentik-email", c.Email)
|
||||
headers.Set("X-authentik-name", c.Name)
|
||||
headers.Set("X-authentik-uid", c.Sub)
|
||||
headers.Set("X-authentik-jwt", c.RawToken)
|
||||
|
||||
// System headers
|
||||
headers.Set("X-authentik-meta-jwks", a.endpoint.JwksUri)
|
||||
headers.Set("X-authentik-meta-outpost", a.outpostName)
|
||||
headers.Set("X-authentik-meta-provider", a.proxyConfig.Name)
|
||||
headers.Set("X-authentik-meta-app", a.proxyConfig.AssignedApplicationSlug)
|
||||
headers.Set("X-authentik-meta-version", constants.UserAgentOutpost())
|
||||
|
||||
if c.Proxy == nil {
|
||||
return
|
||||
}
|
||||
userAttributes := c.Proxy.UserAttributes
|
||||
a.setAuthorizationHeader(headers, c)
|
||||
// Check if user has additional headers set that we should sent
|
||||
if additionalHeaders, ok := userAttributes["additionalHeaders"]; ok {
|
||||
a.log.WithField("headers", additionalHeaders).Trace("setting additional headers")
|
||||
if additionalHeaders == nil {
|
||||
return
|
||||
}
|
||||
for key, value := range additionalHeaders.(map[string]interface{}) {
|
||||
headers.Set(key, toString(value))
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Basic %s", authVal)
|
||||
}
|
||||
|
||||
// getTraefikForwardUrl See https://doc.traefik.io/traefik/middlewares/forwardauth/
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/proxyv2/types"
|
||||
)
|
||||
|
||||
func urlMustParse(u string) *url.URL {
|
||||
@@ -48,3 +51,135 @@ func TestIsAllowlisted_Proxy_Domain(t *testing.T) {
|
||||
assert.Equal(t, false, a.IsAllowlisted(urlMustParse("https://health.domain.tld/")))
|
||||
assert.Equal(t, true, a.IsAllowlisted(urlMustParse("https://health.domain.tld/ping/qq")))
|
||||
}
|
||||
|
||||
func TestAdHeaders_Standard(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
h := http.Header{}
|
||||
a.addHeaders(h, &types.Claims{
|
||||
PreferredUsername: "foo",
|
||||
Groups: []string{"foo", "bar"},
|
||||
Entitlements: []string{"bar", "quox"},
|
||||
Email: "bar@authentik.company",
|
||||
Name: "foo",
|
||||
Sub: "bar",
|
||||
RawToken: "baz",
|
||||
})
|
||||
assert.Equal(t, http.Header{
|
||||
"X-Authentik-Email": []string{"bar@authentik.company"},
|
||||
"X-Authentik-Entitlements": []string{"bar|quox"},
|
||||
"X-Authentik-Groups": []string{"foo|bar"},
|
||||
"X-Authentik-Jwt": []string{"baz"},
|
||||
"X-Authentik-Meta-App": []string{""},
|
||||
"X-Authentik-Meta-Jwks": []string{""},
|
||||
"X-Authentik-Meta-Outpost": []string{""},
|
||||
"X-Authentik-Meta-Provider": []string{a.proxyConfig.Name},
|
||||
"X-Authentik-Meta-Version": []string{constants.UserAgentOutpost()},
|
||||
"X-Authentik-Name": []string{"foo"},
|
||||
"X-Authentik-Uid": []string{"bar"},
|
||||
"X-Authentik-Username": []string{"foo"},
|
||||
}, h)
|
||||
}
|
||||
|
||||
func TestAdHeaders_BasicAuth(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
a.proxyConfig.BasicAuthEnabled = api.PtrBool(true)
|
||||
a.proxyConfig.BasicAuthUserAttribute = api.PtrString("user")
|
||||
a.proxyConfig.BasicAuthPasswordAttribute = api.PtrString("pass")
|
||||
h := http.Header{}
|
||||
a.addHeaders(h, &types.Claims{
|
||||
PreferredUsername: "foo",
|
||||
Groups: []string{"foo", "bar"},
|
||||
Entitlements: []string{"bar", "quox"},
|
||||
Email: "bar@authentik.company",
|
||||
Name: "foo",
|
||||
Sub: "bar",
|
||||
RawToken: "baz",
|
||||
Proxy: &types.ProxyClaims{
|
||||
UserAttributes: map[string]any{
|
||||
"user": "foo",
|
||||
"pass": "baz",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Equal(t, http.Header{
|
||||
"Authorization": []string{"Basic Zm9vOmJheg=="},
|
||||
"X-Authentik-Email": []string{"bar@authentik.company"},
|
||||
"X-Authentik-Entitlements": []string{"bar|quox"},
|
||||
"X-Authentik-Groups": []string{"foo|bar"},
|
||||
"X-Authentik-Jwt": []string{"baz"},
|
||||
"X-Authentik-Meta-App": []string{""},
|
||||
"X-Authentik-Meta-Jwks": []string{""},
|
||||
"X-Authentik-Meta-Outpost": []string{""},
|
||||
"X-Authentik-Meta-Provider": []string{a.proxyConfig.Name},
|
||||
"X-Authentik-Meta-Version": []string{constants.UserAgentOutpost()},
|
||||
"X-Authentik-Name": []string{"foo"},
|
||||
"X-Authentik-Uid": []string{"bar"},
|
||||
"X-Authentik-Username": []string{"foo"},
|
||||
}, h)
|
||||
}
|
||||
|
||||
func TestAdHeaders_Extra(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
h := http.Header{}
|
||||
a.addHeaders(h, &types.Claims{
|
||||
PreferredUsername: "foo",
|
||||
Groups: []string{"foo", "bar"},
|
||||
Entitlements: []string{"bar", "quox"},
|
||||
Email: "bar@authentik.company",
|
||||
Name: "foo",
|
||||
Sub: "bar",
|
||||
RawToken: "baz",
|
||||
Proxy: &types.ProxyClaims{
|
||||
UserAttributes: map[string]any{
|
||||
"additionalHeaders": map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Equal(t, http.Header{
|
||||
"Foo": []string{"bar"},
|
||||
"X-Authentik-Email": []string{"bar@authentik.company"},
|
||||
"X-Authentik-Entitlements": []string{"bar|quox"},
|
||||
"X-Authentik-Groups": []string{"foo|bar"},
|
||||
"X-Authentik-Jwt": []string{"baz"},
|
||||
"X-Authentik-Meta-App": []string{""},
|
||||
"X-Authentik-Meta-Jwks": []string{""},
|
||||
"X-Authentik-Meta-Outpost": []string{""},
|
||||
"X-Authentik-Meta-Provider": []string{a.proxyConfig.Name},
|
||||
"X-Authentik-Meta-Version": []string{constants.UserAgentOutpost()},
|
||||
"X-Authentik-Name": []string{"foo"},
|
||||
"X-Authentik-Uid": []string{"bar"},
|
||||
"X-Authentik-Username": []string{"foo"},
|
||||
}, h)
|
||||
}
|
||||
|
||||
func TestAdHeaders_UnderscoreInitial(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
h := http.Header{}
|
||||
h.Set("X_AUTHENTIK_USERNAME", "another user")
|
||||
h.Set("X-Authentik_username", "another user")
|
||||
a.addHeaders(h, &types.Claims{
|
||||
PreferredUsername: "foo",
|
||||
Groups: []string{"foo", "bar"},
|
||||
Entitlements: []string{"bar", "quox"},
|
||||
Email: "bar@authentik.company",
|
||||
Name: "foo",
|
||||
Sub: "bar",
|
||||
RawToken: "baz",
|
||||
})
|
||||
assert.Equal(t, http.Header{
|
||||
"X-Authentik-Email": []string{"bar@authentik.company"},
|
||||
"X-Authentik-Entitlements": []string{"bar|quox"},
|
||||
"X-Authentik-Groups": []string{"foo|bar"},
|
||||
"X-Authentik-Jwt": []string{"baz"},
|
||||
"X-Authentik-Meta-App": []string{""},
|
||||
"X-Authentik-Meta-Jwks": []string{""},
|
||||
"X-Authentik-Meta-Outpost": []string{""},
|
||||
"X-Authentik-Meta-Provider": []string{a.proxyConfig.Name},
|
||||
"X-Authentik-Meta-Version": []string{constants.UserAgentOutpost()},
|
||||
"X-Authentik-Name": []string{"foo"},
|
||||
"X-Authentik-Uid": []string{"bar"},
|
||||
"X-Authentik-Username": []string{"foo"},
|
||||
}, h)
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func (a *Application) proxyModifyRequest(ou *url.URL) func(req *http.Request) {
|
||||
r.Header.Set("X-Forwarded-Host", r.Host)
|
||||
r.URL.Scheme = ou.Scheme
|
||||
r.URL.Host = ou.Host
|
||||
claims := a.getClaimsFromSession(r)
|
||||
claims := a.getClaimsFromSession(nil, r)
|
||||
if claims != nil && claims.Proxy != nil {
|
||||
if claims.Proxy.BackendOverride != "" {
|
||||
u, err := url.Parse(claims.Proxy.BackendOverride)
|
||||
|
||||
@@ -19,6 +19,7 @@ func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, f
|
||||
state, err := a.createState(r, rw, fwd)
|
||||
if err != nil {
|
||||
a.log.WithError(err).Warning("failed to create state")
|
||||
rw.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(state), http.StatusFound)
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request) {
|
||||
state := a.stateFromRequest(r)
|
||||
state := a.stateFromRequest(rw, r)
|
||||
if state == nil {
|
||||
a.log.Warning("invalid state")
|
||||
a.redirect(rw, r)
|
||||
|
||||
@@ -96,7 +96,7 @@ func (a *Application) createState(r *http.Request, w http.ResponseWriter, fwd st
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func (a *Application) stateFromRequest(r *http.Request) *OAuthState {
|
||||
func (a *Application) stateFromRequest(rw http.ResponseWriter, r *http.Request) *OAuthState {
|
||||
stateJwt := r.URL.Query().Get("state")
|
||||
token, err := jwt.Parse(stateJwt, func(token *jwt.Token) (interface{}, error) {
|
||||
// Don't forget to validate the alg is what you expect:
|
||||
@@ -127,6 +127,13 @@ func (a *Application) stateFromRequest(r *http.Request) *OAuthState {
|
||||
s, err := a.sessions.Get(r, a.SessionName())
|
||||
if err != nil {
|
||||
a.log.WithError(err).Warning("failed to get session")
|
||||
// Delete the stale session cookie if it exists
|
||||
if rw != nil {
|
||||
s.Options.MaxAge = -1
|
||||
if saveErr := s.Save(r, rw); saveErr != nil {
|
||||
a.log.WithError(saveErr).Warning("failed to delete stale session cookie")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if claims.SessionID != s.ID {
|
||||
|
||||
@@ -29,7 +29,10 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
|
||||
// Add one to the validity to ensure we don't have a session with indefinite length
|
||||
maxAge = int(*t) + 1
|
||||
}
|
||||
if a.isEmbedded {
|
||||
|
||||
sessionBackend := a.srv.SessionBackend()
|
||||
switch sessionBackend {
|
||||
case "postgres":
|
||||
// New PostgreSQL store
|
||||
ps, err := postgresstore.NewPostgresStore()
|
||||
if err != nil {
|
||||
@@ -46,30 +49,32 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
a.log.Trace("using postgresql session backend")
|
||||
return ps, nil
|
||||
}
|
||||
dir := os.TempDir()
|
||||
cs, err := filesystemstore.GetPersistentStore(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.Codecs = codecs.CodecsFromPairs(maxAge, []byte(*p.CookieSecret))
|
||||
// https://github.com/markbates/goth/commit/7276be0fdf719ddff753f3574ef0f967e4a5a5f7
|
||||
// set the maxLength of the cookies stored on the disk to a larger number to prevent issues with:
|
||||
// securecookie: the value is too long
|
||||
// when using OpenID Connect, since this can contain a large amount of extra information in the id_token
|
||||
case "filesystem":
|
||||
dir := os.TempDir()
|
||||
cs, err := filesystemstore.GetPersistentStore(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.Codecs = codecs.CodecsFromPairs(maxAge, []byte(*p.CookieSecret))
|
||||
// https://github.com/markbates/goth/commit/7276be0fdf719ddff753f3574ef0f967e4a5a5f7
|
||||
// set the maxLength of the cookies stored on the disk to a larger number to prevent issues with:
|
||||
// securecookie: the value is too long
|
||||
// when using OpenID Connect, since this can contain a large amount of extra information in the id_token
|
||||
|
||||
// Note, when using the FilesystemStore only the session.ID is written to a browser cookie, so this is explicit for the storage on disk
|
||||
cs.MaxLength(math.MaxInt)
|
||||
cs.Options.HttpOnly = true
|
||||
cs.Options.Secure = strings.ToLower(externalHost.Scheme) == "https"
|
||||
cs.Options.Domain = *p.CookieDomain
|
||||
cs.Options.SameSite = http.SameSiteLaxMode
|
||||
cs.Options.MaxAge = maxAge
|
||||
cs.Options.Path = "/"
|
||||
a.log.WithField("dir", dir).Trace("using filesystem session backend")
|
||||
return cs, nil
|
||||
// Note, when using the FilesystemStore only the session.ID is written to a browser cookie, so this is explicit for the storage on disk
|
||||
cs.MaxLength(math.MaxInt)
|
||||
cs.Options.HttpOnly = true
|
||||
cs.Options.Secure = strings.ToLower(externalHost.Scheme) == "https"
|
||||
cs.Options.Domain = *p.CookieDomain
|
||||
cs.Options.SameSite = http.SameSiteLaxMode
|
||||
cs.Options.MaxAge = maxAge
|
||||
cs.Options.Path = "/"
|
||||
return cs, nil
|
||||
default:
|
||||
a.log.WithField("backend", sessionBackend).Panic("unknown session backend type")
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) SessionName() string {
|
||||
|
||||
@@ -76,3 +76,81 @@ func TestLogout(t *testing.T) {
|
||||
_, err = os.Stat(s2Name)
|
||||
assert.True(t, errors.Is(err, os.ErrNotExist))
|
||||
}
|
||||
|
||||
func TestStaleCookieDeletion(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
_ = a.configureProxy()
|
||||
|
||||
// Create a request with a session cookie that references a non-existent session file
|
||||
req, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/foo", nil)
|
||||
|
||||
// Set a cookie for a session that doesn't exist (simulates pod restart)
|
||||
nonExistentSessionID := uuid.New().String()
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: a.SessionName(),
|
||||
Value: "encoded_session_data_" + nonExistentSessionID,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call getClaimsFromSession which should delete the stale cookie
|
||||
claims := a.getClaimsFromSession(rr, req)
|
||||
|
||||
// Verify no claims were returned (session doesn't exist)
|
||||
assert.Nil(t, claims)
|
||||
|
||||
// Verify the response includes a Set-Cookie header to delete the stale cookie
|
||||
cookies := rr.Result().Cookies()
|
||||
var foundDeleteCookie bool
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == a.SessionName() && cookie.MaxAge < 0 {
|
||||
foundDeleteCookie = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted")
|
||||
}
|
||||
|
||||
func TestStateFromRequestDeletesStaleCookie(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
_ = a.configureProxy()
|
||||
|
||||
// Create a valid state JWT (from createState)
|
||||
req, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
state, err := a.createState(req, rr, "/redirect")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a new request with the state but a stale session cookie
|
||||
req2, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/callback?state="+state, nil)
|
||||
|
||||
// Add a cookie for a non-existent session
|
||||
nonExistentSessionID := uuid.New().String()
|
||||
req2.AddCookie(&http.Cookie{
|
||||
Name: a.SessionName(),
|
||||
Value: "encoded_session_data_" + nonExistentSessionID,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
rr2 := httptest.NewRecorder()
|
||||
|
||||
// Call stateFromRequest which should fail due to missing session
|
||||
// but should also delete the stale cookie
|
||||
claims := a.stateFromRequest(rr2, req2)
|
||||
|
||||
// Verify no claims were returned
|
||||
assert.Nil(t, claims)
|
||||
|
||||
// Verify the response includes a Set-Cookie header to delete the stale cookie
|
||||
cookies := rr2.Result().Cookies()
|
||||
var foundDeleteCookie bool
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == a.SessionName() && cookie.MaxAge < 0 {
|
||||
foundDeleteCookie = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted")
|
||||
}
|
||||
|
||||
@@ -41,6 +41,10 @@ func (ts *testServer) Apps() []*Application {
|
||||
return ts.apps
|
||||
}
|
||||
|
||||
func (ts *testServer) SessionBackend() string {
|
||||
return "filesystem"
|
||||
}
|
||||
|
||||
func newTestApplication() *Application {
|
||||
ts := newTestServer()
|
||||
a, _ := NewApplication(
|
||||
@@ -83,7 +87,7 @@ func (a *Application) assertState(t *testing.T, req *http.Request, response *htt
|
||||
nrq.Set("state", state)
|
||||
nr.URL.RawQuery = nrq.Encode()
|
||||
// parse state
|
||||
parsed := a.stateFromRequest(nr)
|
||||
parsed := a.stateFromRequest(nil, nr)
|
||||
if parsed == nil {
|
||||
panic("Could not parse state")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func urlJoin(originalUrl string, newPath string) string {
|
||||
|
||||
func (a *Application) redirect(rw http.ResponseWriter, r *http.Request) {
|
||||
fallbackRedirect := a.proxyConfig.ExternalHost
|
||||
state := a.stateFromRequest(r)
|
||||
state := a.stateFromRequest(rw, r)
|
||||
if state == nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
type RefreshableConnPool struct {
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
dsnBuilder func(config.PostgreSQLConfig) (string, error)
|
||||
log *log.Entry
|
||||
currentDSN string
|
||||
gormConfig *gorm.Config
|
||||
@@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC
|
||||
|
||||
pool := &RefreshableConnPool{
|
||||
db: db,
|
||||
dsnBuilder: BuildDSN,
|
||||
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
|
||||
currentDSN: initialDSN,
|
||||
gormConfig: gormConfig,
|
||||
@@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {
|
||||
|
||||
// Get fresh config
|
||||
cfg := config.Get().RefreshPostgreSQLConfig()
|
||||
newDSN, err := p.dsnBuilder(cfg)
|
||||
newDSN, err := BuildDSN(cfg)
|
||||
if err != nil {
|
||||
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
|
||||
return err
|
||||
|
||||
@@ -2,17 +2,21 @@ package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
_ "gorm.io/driver/postgres"
|
||||
@@ -51,60 +55,121 @@ func (ProxySession) TableName() string {
|
||||
return "authentik_providers_proxy_proxysession"
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration.
|
||||
func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
// Validate required fields
|
||||
if cfg.Host == "" {
|
||||
return "", fmt.Errorf("PostgreSQL host is required")
|
||||
return nil, fmt.Errorf("PostgreSQL host is required")
|
||||
}
|
||||
if cfg.User == "" {
|
||||
return "", fmt.Errorf("PostgreSQL user is required")
|
||||
return nil, fmt.Errorf("PostgreSQL user is required")
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return "", fmt.Errorf("PostgreSQL database name is required")
|
||||
return nil, fmt.Errorf("PostgreSQL database name is required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
return "", fmt.Errorf("PostgreSQL port must be positive")
|
||||
return nil, fmt.Errorf("PostgreSQL port must be positive")
|
||||
}
|
||||
|
||||
// Build DSN string with all parameters
|
||||
dsnParts := []string{
|
||||
"host=" + cfg.Host,
|
||||
fmt.Sprintf("port=%d", cfg.Port),
|
||||
"user=" + cfg.User,
|
||||
"dbname=" + cfg.Name,
|
||||
// Start with a default config
|
||||
connConfig, err := pgx.ParseConfig("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Password != "" {
|
||||
dsnParts = append(dsnParts, "password="+cfg.Password)
|
||||
}
|
||||
// Set connection parameters
|
||||
connConfig.Host = cfg.Host
|
||||
connConfig.Port = uint16(cfg.Port)
|
||||
connConfig.User = cfg.User
|
||||
connConfig.Password = cfg.Password
|
||||
connConfig.Database = cfg.Name
|
||||
|
||||
// Add SSL mode
|
||||
// Configure TLS/SSL
|
||||
if cfg.SSLMode != "" {
|
||||
dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode)
|
||||
switch cfg.SSLMode {
|
||||
case "disable":
|
||||
connConfig.TLSConfig = nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Load root CA certificate if provided
|
||||
if cfg.SSLRootCert != "" {
|
||||
caCert, err := os.ReadFile(cfg.SSLRootCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read SSL root certificate: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse SSL root certificate")
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
// Load client certificate and key if provided
|
||||
if cfg.SSLCert != "" && cfg.SSLKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load SSL client certificate: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set verification mode
|
||||
switch cfg.SSLMode {
|
||||
case "require":
|
||||
// Don't verify the server certificate (just encrypt)
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "verify-ca":
|
||||
// Verify the certificate is signed by a trusted CA
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
case "verify-full":
|
||||
// Verify the certificate and hostname
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.ServerName = cfg.Host
|
||||
}
|
||||
|
||||
connConfig.TLSConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Add SSL certificates if provided
|
||||
if cfg.SSLRootCert != "" {
|
||||
dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert)
|
||||
}
|
||||
if cfg.SSLCert != "" {
|
||||
dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert)
|
||||
}
|
||||
if cfg.SSLKey != "" {
|
||||
dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey)
|
||||
// Set runtime params
|
||||
if connConfig.RuntimeParams == nil {
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
}
|
||||
|
||||
if cfg.DefaultSchema != "" {
|
||||
dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema)
|
||||
connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema
|
||||
}
|
||||
|
||||
// Add connection options if specified
|
||||
// Parse and apply connection options if specified
|
||||
if cfg.ConnOptions != "" {
|
||||
dsnParts = append(dsnParts, cfg.ConnOptions)
|
||||
// Parse key=value pairs from ConnOptions
|
||||
// Format: "key1=value1 key2=value2"
|
||||
pairs := strings.Split(cfg.ConnOptions, " ")
|
||||
for _, pair := range pairs {
|
||||
if pair == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) == 2 {
|
||||
connConfig.RuntimeParams[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Join parts with spaces
|
||||
return strings.Join(dsnParts, " "), nil
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Register the config and get a connection string
|
||||
// (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password)
|
||||
return stdlib.RegisterConnConfig(connConfig), nil
|
||||
}
|
||||
|
||||
// SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool.
|
||||
|
||||
@@ -2,14 +2,23 @@ package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
@@ -541,11 +550,11 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
func TestBuildConnConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.PostgreSQLConfig
|
||||
expected string
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
}{
|
||||
{
|
||||
name: "basic configuration",
|
||||
@@ -555,10 +564,16 @@ func TestBuildDSN(t *testing.T) {
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "localhost", cc.Host)
|
||||
assert.Equal(t, uint16(5432), cc.Port)
|
||||
assert.Equal(t, "testuser", cc.User)
|
||||
assert.Equal(t, "testdb", cc.Database)
|
||||
assert.Equal(t, "", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password",
|
||||
name: "with simple password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
@@ -566,7 +581,87 @@ func TestBuildDSN(t *testing.T) {
|
||||
Password: "testpass",
|
||||
Name: "testdb",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb password=testpass",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "testpass", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing spaces",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "my secure password", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing single quotes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "pass'word",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "pass'word", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `pass\word`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `pass\word`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing special characters",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `p@ss w0rd!#$%^&*()`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing quotes and backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `my'pass\word"here`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `my'pass\word"here`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with passphrase (multiple spaces)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "the quick brown fox jumps over",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "the quick brown fox jumps over", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with sslmode=disable",
|
||||
@@ -577,10 +672,12 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=disable",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Nil(t, cc.TLSConfig)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with sslmode=require",
|
||||
name: "with sslmode=require (no certs)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
@@ -588,32 +685,10 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=require",
|
||||
},
|
||||
{
|
||||
name: "with sslmode=prefer",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "prefer",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.NotNil(t, cc.TLSConfig)
|
||||
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=prefer",
|
||||
},
|
||||
{
|
||||
name: "with SSL certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: "/path/to/root.crt",
|
||||
SSLCert: "/path/to/client.crt",
|
||||
SSLKey: "/path/to/client.key",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=verify-full sslrootcert=/path/to/root.crt sslcert=/path/to/client.crt sslkey=/path/to/client.key",
|
||||
},
|
||||
{
|
||||
name: "with custom schema",
|
||||
@@ -624,7 +699,9 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
DefaultSchema: "custom_schema",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb search_path=custom_schema",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "custom_schema", cc.RuntimeParams["search_path"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with connection options",
|
||||
@@ -633,34 +710,192 @@ func TestBuildDSN(t *testing.T) {
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: "connect_timeout=10",
|
||||
ConnOptions: "connect_timeout=10 application_name=authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb connect_timeout=10",
|
||||
},
|
||||
{
|
||||
name: "full configuration",
|
||||
name: "full configuration with special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
User: "admin",
|
||||
Password: "secret",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: "/certs/root.crt",
|
||||
SSLCert: "/certs/client.crt",
|
||||
SSLKey: "/certs/client.key",
|
||||
SSLMode: "require",
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
},
|
||||
expected: "host=db.example.com port=5433 user=admin dbname=production password=secret sslmode=verify-full sslrootcert=/certs/root.crt sslcert=/certs/client.crt sslkey=/certs/client.key search_path=app_schema application_name=authentik",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
assert.Equal(t, uint16(5433), cc.Port)
|
||||
assert.Equal(t, "admin", cc.User)
|
||||
assert.Equal(t, "my super secret password!@#", cc.Password)
|
||||
assert.Equal(t, "production", cc.Database)
|
||||
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := BuildDSN(tt.cfg)
|
||||
result, err := BuildConnConfig(tt.cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
require.NotNil(t, result)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration
|
||||
func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.PostgreSQLConfig
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
}{
|
||||
{
|
||||
name: "verify-full with all certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: rootCertPath,
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify-ca with root cert only",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-ca",
|
||||
SSLRootCert: rootCertPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Empty(t, cc.TLSConfig.Certificates)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "require with client cert",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full configuration with SSL and special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: rootCertPath,
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
assert.Equal(t, uint16(5433), cc.Port)
|
||||
assert.Equal(t, "admin", cc.User)
|
||||
assert.Equal(t, "my super secret password!@#", cc.Password)
|
||||
assert.Equal(t, "production", cc.Database)
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := BuildConnConfig(tt.cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters
|
||||
// by verifying the DSN can actually be used to connect to a database
|
||||
func TestBuildDSN_WithSpecialPasswords(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
}{
|
||||
{"space in password", "my password"},
|
||||
{"multiple spaces", "the quick brown fox"},
|
||||
{"single quote", "pass'word"},
|
||||
{"backslash", `pass\word`},
|
||||
{"double quote", `pass"word`},
|
||||
{"special chars", `p@ss!#$%^&*()`},
|
||||
{"mixed special", `my'pass\word"here`},
|
||||
{"unicode", "pässwörd"},
|
||||
{"leading/trailing spaces", " password "},
|
||||
{"tab character", "pass\tword"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: tt.password,
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
// Test that BuildDSN doesn't error
|
||||
dsn, err := BuildDSN(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, dsn)
|
||||
|
||||
// Test that BuildConnConfig preserves the password exactly
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -715,3 +950,89 @@ func createSessionData(t *testing.T, claims map[string]interface{}) string {
|
||||
require.NoError(t, err)
|
||||
return string(sessionDataJSON)
|
||||
}
|
||||
|
||||
// generateTestCerts creates temporary SSL certificates for testing
|
||||
func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Generate CA certificate
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test CA"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write CA certificate
|
||||
rootCertPath = filepath.Join(tmpDir, "root.crt")
|
||||
rootCertFile, err := os.Create(rootCertPath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if closeErr := rootCertFile.Close(); closeErr != nil {
|
||||
t.Logf("failed to close root cert file: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client certificate
|
||||
clientTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Client"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write client certificate
|
||||
clientCertPath = filepath.Join(tmpDir, "client.crt")
|
||||
clientCertFile, err := os.Create(clientCertPath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if closeErr := clientCertFile.Close(); closeErr != nil {
|
||||
t.Logf("failed to close client cert file: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write client key
|
||||
clientKeyPath = filepath.Join(tmpDir, "client.key")
|
||||
clientKeyFile, err := os.Create(clientKeyPath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if closeErr := clientKeyFile.Close(); closeErr != nil {
|
||||
t.Logf("failed to close client key file: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey)
|
||||
err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes})
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup = func() {
|
||||
// TempDir cleanup is automatic in Go tests
|
||||
}
|
||||
|
||||
return rootCertPath, clientCertPath, clientKeyPath, cleanup
|
||||
}
|
||||
|
||||
@@ -55,6 +55,11 @@ func NewProxyServer(ac *ak.APIController) ak.Outpost {
|
||||
if ac.GlobalConfig.ErrorReporting.Enabled {
|
||||
globalMux.Use(sentryhttp.New(sentryhttp.Options{}).Handle)
|
||||
}
|
||||
if ac.IsEmbedded() {
|
||||
l.Info("using PostgreSQL session backend")
|
||||
} else {
|
||||
l.Info("using filesystem session backend")
|
||||
}
|
||||
s := &ProxyServer{
|
||||
cryptoStore: ak.NewCryptoStore(ac.Client.CryptoApi),
|
||||
apps: make(map[string]*application.Application),
|
||||
|
||||
@@ -15,7 +15,9 @@ import (
|
||||
)
|
||||
|
||||
func (ps *ProxyServer) Refresh() error {
|
||||
providers, err := ak.Paginator(ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()), ak.PaginatorOptions{
|
||||
req := ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background())
|
||||
ps.log.WithField("outpost_pk", ps.akAPI.Outpost.Pk).Debug("Requesting providers for outpost")
|
||||
providers, err := ak.Paginator(req, ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ps.log,
|
||||
})
|
||||
@@ -25,6 +27,13 @@ func (ps *ProxyServer) Refresh() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ps.log.WithField("count", len(providers)).Debug("Fetched providers")
|
||||
if len(providers) == 0 {
|
||||
ps.log.Warning("No providers assigned to this outpost, check outpost configuration in authentik")
|
||||
}
|
||||
for i, p := range providers {
|
||||
ps.log.WithField("index", i).WithField("name", p.Name).WithField("external_host", p.ExternalHost).WithField("assigned_to_app", p.AssignedApplicationName).Debug("Provider details")
|
||||
}
|
||||
apps := make(map[string]*application.Application)
|
||||
for _, provider := range providers {
|
||||
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
||||
@@ -52,6 +61,7 @@ func (ps *ProxyServer) Refresh() error {
|
||||
ps.log.WithError(err).Warning("failed to setup application")
|
||||
continue
|
||||
}
|
||||
ps.log.WithField("name", provider.Name).WithField("host", externalHost.Host).Info("Loaded application")
|
||||
apps[externalHost.Host] = a
|
||||
}
|
||||
ps.apps = apps
|
||||
@@ -70,3 +80,14 @@ func (ps *ProxyServer) CryptoStore() *ak.CryptoStore {
|
||||
func (ps *ProxyServer) Apps() []*application.Application {
|
||||
return maps.Values(ps.apps)
|
||||
}
|
||||
|
||||
func (ps *ProxyServer) SessionBackend() string {
|
||||
if ps.akAPI.IsEmbedded() {
|
||||
return "postgres"
|
||||
}
|
||||
if !ps.akAPI.IsEmbedded() {
|
||||
return "filesystem"
|
||||
}
|
||||
ps.log.Panic("failed to determine session backend type")
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package types
|
||||
|
||||
type ProxyClaims struct {
|
||||
UserAttributes map[string]interface{} `json:"user_attributes"`
|
||||
BackendOverride string `json:"backend_override"`
|
||||
HostHeader string `json:"host_header"`
|
||||
IsSuperuser bool `json:"is_superuser"`
|
||||
UserAttributes map[string]any `json:"user_attributes" mapstructure:"user_attributes"`
|
||||
BackendOverride string `json:"backend_override" mapstructure:"backend_override"`
|
||||
HostHeader string `json:"host_header" mapstructure:"host_header"`
|
||||
IsSuperuser bool `json:"is_superuser" mapstructure:"is_superuser"`
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
@@ -19,5 +19,5 @@ type Claims struct {
|
||||
Sid string `json:"sid" mapstructure:"sid"`
|
||||
Proxy *ProxyClaims `json:"ak_proxy" mapstructure:"ak_proxy"`
|
||||
|
||||
RawToken string `mapstructure:"-"`
|
||||
RawToken string `json:"raw_token" mapstructure:"raw_token"`
|
||||
}
|
||||
|
||||
@@ -41,95 +41,92 @@ func (pi *ProviderInstance) SetEAPState(key string, state *protocol.State) {
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
|
||||
protocols := []protocol.ProtocolConstructor{
|
||||
identity.Protocol,
|
||||
legacy_nak.Protocol,
|
||||
settings := protocol.Settings{
|
||||
Logger: &logrusAdapter{pi.log},
|
||||
Protocols: []protocol.ProtocolConstructor{
|
||||
identity.Protocol,
|
||||
legacy_nak.Protocol,
|
||||
},
|
||||
}
|
||||
|
||||
certId := pi.certId
|
||||
if certId == "" {
|
||||
return protocol.Settings{
|
||||
Protocols: protocols,
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
cert := pi.s.cryptoStore.Get(certId)
|
||||
if cert == nil {
|
||||
return protocol.Settings{
|
||||
Protocols: protocols,
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
return protocol.Settings{
|
||||
Logger: &logrusAdapter{entry: pi.log},
|
||||
Protocols: append(protocols, tls.Protocol, peap.Protocol),
|
||||
ProtocolPriority: []protocol.Type{
|
||||
identity.TypeIdentity,
|
||||
tls.TypeTLS,
|
||||
},
|
||||
ProtocolSettings: map[protocol.Type]interface{}{
|
||||
tls.TypeTLS: tls.Settings{
|
||||
Config: &ttls.Config{
|
||||
Certificates: []ttls.Certificate{*cert},
|
||||
ClientAuth: ttls.RequireAnyClientCert,
|
||||
},
|
||||
HandshakeSuccessful: func(ctx protocol.Context, certs []*x509.Certificate) protocol.Status {
|
||||
ident := ctx.GetProtocolState(identity.TypeIdentity).(*identity.State).Identity
|
||||
settings.Protocols = append(settings.Protocols, tls.Protocol, peap.Protocol)
|
||||
settings.ProtocolPriority = []protocol.Type{
|
||||
identity.TypeIdentity,
|
||||
tls.TypeTLS,
|
||||
}
|
||||
settings.ProtocolSettings = map[protocol.Type]any{
|
||||
tls.TypeTLS: tls.Settings{
|
||||
Config: &ttls.Config{
|
||||
Certificates: []ttls.Certificate{*cert},
|
||||
ClientAuth: ttls.RequireAnyClientCert,
|
||||
},
|
||||
HandshakeSuccessful: func(ctx protocol.Context, certs []*x509.Certificate) protocol.Status {
|
||||
ident := ctx.GetProtocolState(identity.TypeIdentity).(*identity.State).Identity
|
||||
|
||||
ctx.Log().Debug("Starting authn flow")
|
||||
pem := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certs[0].Raw,
|
||||
ctx.Log().Debug("Starting authn flow")
|
||||
pem := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certs[0].Raw,
|
||||
})
|
||||
|
||||
fe := flow.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{
|
||||
"client": utils.GetIP(ctx.Packet().RemoteAddr),
|
||||
"identity": ident,
|
||||
})
|
||||
fe.Answers[flow.StageIdentification] = ident
|
||||
fe.DelegateClientIP(utils.GetIP(ctx.Packet().RemoteAddr))
|
||||
fe.Params.Add("goauthentik.io/outpost/radius", "true")
|
||||
fe.AddHeader("X-Authentik-Outpost-Certificate", url.QueryEscape(string(pem)))
|
||||
|
||||
passed, err := fe.Execute()
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to execute flow", "error", err)
|
||||
return protocol.StatusError
|
||||
}
|
||||
ctx.Log().Debug("Finished flow")
|
||||
if !passed {
|
||||
return protocol.StatusError
|
||||
}
|
||||
access, _, err := fe.ApiClient().OutpostsApi.OutpostsRadiusAccessCheck(context.Background(), pi.providerId).AppSlug(pi.appSlug).Execute()
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to check access: %v", err)
|
||||
return protocol.StatusError
|
||||
}
|
||||
if !access.Access.Passing {
|
||||
ctx.Log().Info("Access denied for user")
|
||||
return protocol.StatusError
|
||||
}
|
||||
if access.HasAttributes() {
|
||||
ctx.AddResponseModifier(func(r, q *radius.Packet) error {
|
||||
rawData, err := base64.StdEncoding.DecodeString(access.GetAttributes())
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to decode attributes from core: %v", err)
|
||||
return errors.New("attribute_decode_failed")
|
||||
}
|
||||
p, err := radius.Parse(rawData, pi.SharedSecret)
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to parse attributes from core: %v", err)
|
||||
return errors.New("attribute_parse_failed")
|
||||
}
|
||||
for _, attr := range p.Attributes {
|
||||
r.Add(attr.Type, attr.Attribute)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
fe := flow.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{
|
||||
"client": utils.GetIP(ctx.Packet().RemoteAddr),
|
||||
"identity": ident,
|
||||
})
|
||||
fe.Answers[flow.StageIdentification] = ident
|
||||
fe.DelegateClientIP(utils.GetIP(ctx.Packet().RemoteAddr))
|
||||
fe.Params.Add("goauthentik.io/outpost/radius", "true")
|
||||
fe.AddHeader("X-Authentik-Outpost-Certificate", url.QueryEscape(string(pem)))
|
||||
|
||||
passed, err := fe.Execute()
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to execute flow", "error", err)
|
||||
return protocol.StatusError
|
||||
}
|
||||
ctx.Log().Debug("Finished flow")
|
||||
if !passed {
|
||||
return protocol.StatusError
|
||||
}
|
||||
access, _, err := fe.ApiClient().OutpostsApi.OutpostsRadiusAccessCheck(context.Background(), pi.providerId).AppSlug(pi.appSlug).Execute()
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to check access: %v", err)
|
||||
return protocol.StatusError
|
||||
}
|
||||
if !access.Access.Passing {
|
||||
ctx.Log().Info("Access denied for user")
|
||||
return protocol.StatusError
|
||||
}
|
||||
if access.HasAttributes() {
|
||||
ctx.AddResponseModifier(func(r, q *radius.Packet) error {
|
||||
rawData, err := base64.StdEncoding.DecodeString(access.GetAttributes())
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to decode attributes from core: %v", err)
|
||||
return errors.New("attribute_decode_failed")
|
||||
}
|
||||
p, err := radius.Parse(rawData, pi.SharedSecret)
|
||||
if err != nil {
|
||||
ctx.Log().Warn("failed to parse attributes from core: %v", err)
|
||||
return errors.New("attribute_parse_failed")
|
||||
}
|
||||
for _, attr := range p.Attributes {
|
||||
r.Add(attr.Type, attr.Attribute)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return protocol.StatusSuccess
|
||||
},
|
||||
}
|
||||
return protocol.StatusSuccess
|
||||
},
|
||||
},
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
@@ -19,9 +19,7 @@ import (
|
||||
staticWeb "goauthentik.io/web"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthentikStarting = errors.New("authentik starting")
|
||||
)
|
||||
var ErrAuthentikStarting = errors.New("authentik starting")
|
||||
|
||||
const (
|
||||
maxBodyBytes = 32 * 1024 * 1024
|
||||
@@ -99,11 +97,11 @@ func (ws *WebServer) proxyErrorHandler(rw http.ResponseWriter, req *http.Request
|
||||
|
||||
if strings.Contains(accept, "application/json") {
|
||||
header.Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
||||
err = json.NewEncoder(rw).Encode(map[string]string{
|
||||
"error": "authentik starting",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to write error message")
|
||||
return
|
||||
@@ -113,21 +111,18 @@ func (ws *WebServer) proxyErrorHandler(rw http.ResponseWriter, req *http.Request
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
||||
loadingSplashFile, err := staticWeb.StaticDir.Open("standalone/loading/startup.html")
|
||||
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to open startup splash screen")
|
||||
return
|
||||
}
|
||||
|
||||
loadingSplashHTML, err := io.ReadAll(loadingSplashFile)
|
||||
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to read startup splash screen")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = rw.Write(loadingSplashHTML)
|
||||
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to write startup splash screen")
|
||||
return
|
||||
@@ -138,7 +133,6 @@ func (ws *WebServer) proxyErrorHandler(rw http.ResponseWriter, req *http.Request
|
||||
|
||||
// Fallback to just a status message
|
||||
_, err = rw.Write([]byte("authentik starting"))
|
||||
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to write initializing HTML")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.3-bookworm AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.3-trixie@sha256:7534a6264850325fcce93e47b87a0e3fddd96b308440245e6ab1325fa8a44c91 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@@ -31,13 +31,14 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/ldap ./cmd/ldap
|
||||
|
||||
# Stage 2: Run
|
||||
FROM ghcr.io/goauthentik/fips-debian:bookworm-slim-fips
|
||||
FROM ghcr.io/goauthentik/fips-debian:trixie-slim-fips@sha256:9b4cedf932e97194f1825124830f2eec14254d90162dad28f97e505971543115
|
||||
|
||||
ARG VERSION
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
||||
LABEL org.opencontainers.image.authors="Authentik Security Inc." \
|
||||
org.opencontainers.image.source="https://github.com/goauthentik/authentik" \
|
||||
org.opencontainers.image.description="goauthentik.io LDAP outpost, see https://goauthentik.io for more info." \
|
||||
org.opencontainers.image.documentation="https://docs.goauthentik.io" \
|
||||
org.opencontainers.image.licenses="https://github.com/goauthentik/authentik/blob/main/LICENSE" \
|
||||
|
||||
@@ -18,7 +18,7 @@ Parameters:
|
||||
Description: authentik Docker image
|
||||
AuthentikVersion:
|
||||
Type: String
|
||||
Default: 2025.10.0-rc1
|
||||
Default: 2025.10.3
|
||||
Description: authentik Docker image tag
|
||||
AuthentikServerCPU:
|
||||
Type: Number
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4
package-lock.json
generated
4
package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@goauthentik/authentik",
|
||||
"version": "2025.10.0-rc1",
|
||||
"version": "2025.10.3",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@goauthentik/authentik",
|
||||
"version": "2025.10.0-rc1",
|
||||
"version": "2025.10.3",
|
||||
"dependencies": {
|
||||
"@eslint/js": "^9.31.0",
|
||||
"@typescript-eslint/eslint-plugin": "^8.38.0",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user