mirror of
https://github.com/goauthentik/authentik
synced 2026-05-06 07:02:51 +02:00
Compare commits
196 Commits
metadata-f
...
version/20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6f2f6ceab | ||
|
|
6585bdad4d | ||
|
|
adfab8e322 | ||
|
|
644e8e6915 | ||
|
|
f4848883fe | ||
|
|
29e23ce08c | ||
|
|
0c95d5bbe3 | ||
|
|
15c4de7c5b | ||
|
|
a4187baa10 | ||
|
|
ff42054d9d | ||
|
|
7b0a6b4282 | ||
|
|
8a55050da5 | ||
|
|
87d7ebcfdf | ||
|
|
6ad4bbefcf | ||
|
|
1538e42f3d | ||
|
|
73ac3f6336 | ||
|
|
eb127fd39d | ||
|
|
12978bd87d | ||
|
|
4d8ba745b0 | ||
|
|
90f2a01451 | ||
|
|
177ebe06b2 | ||
|
|
f6a5ddd367 | ||
|
|
dbec7ead5d | ||
|
|
a1e2a50037 | ||
|
|
f06d36e48f | ||
|
|
dd2ad94971 | ||
|
|
bfcdc9ea2f | ||
|
|
b4beb1de9c | ||
|
|
22d09744e0 | ||
|
|
e7d09e820f | ||
|
|
f47749ab60 | ||
|
|
b4f7455f21 | ||
|
|
3beef73f82 | ||
|
|
7ee1fbf267 | ||
|
|
ac0501fb06 | ||
|
|
2b1bfbbb54 | ||
|
|
e9719cf7d5 | ||
|
|
e924a37985 | ||
|
|
3f9ca19d35 | ||
|
|
50e55eea08 | ||
|
|
383d3b89f2 | ||
|
|
8cc768f973 | ||
|
|
03d21be201 | ||
|
|
7d8465bdb5 | ||
|
|
a9b46a4943 | ||
|
|
d3c052559d | ||
|
|
b73b6dcdd3 | ||
|
|
e37bdc6a1d | ||
|
|
b3f1c4736d | ||
|
|
5fd5f3d6ff | ||
|
|
5dbcf6c484 | ||
|
|
056e2c8571 | ||
|
|
0f58a567ce | ||
|
|
502e037d04 | ||
|
|
9b6fae0749 | ||
|
|
dc2332a316 | ||
|
|
c39414f558 | ||
|
|
aac1acfebd | ||
|
|
4d881bb3d2 | ||
|
|
852d392158 | ||
|
|
76b26ea288 | ||
|
|
a1f1378814 | ||
|
|
afc2be6b68 | ||
|
|
c45985e9d0 | ||
|
|
7221ed1ce6 | ||
|
|
123fd3dfb8 | ||
|
|
59c292ca21 | ||
|
|
2b247b60cf | ||
|
|
359a3b9768 | ||
|
|
2c84d73353 | ||
|
|
56ba055857 | ||
|
|
4b9775d9fe | ||
|
|
d06091e226 | ||
|
|
f715e7a537 | ||
|
|
1068dfcc28 | ||
|
|
9a6f66b23c | ||
|
|
853a367325 | ||
|
|
09cdcd1892 | ||
|
|
bed6407b52 | ||
|
|
3936a4e09a | ||
|
|
ad818a2880 | ||
|
|
f8f049f080 | ||
|
|
434e8203de | ||
|
|
7715ce1a90 | ||
|
|
c735dd67a2 | ||
|
|
1b5962be60 | ||
|
|
796d130ea4 | ||
|
|
6c8b502a5b | ||
|
|
674d681f98 | ||
|
|
8c6d3e131d | ||
|
|
b689debfed | ||
|
|
03e4297824 | ||
|
|
c4e0a02837 | ||
|
|
4586ed0735 | ||
|
|
59ef6bb6ea | ||
|
|
6ce812b01f | ||
|
|
87d08dc164 | ||
|
|
c41883b8ea | ||
|
|
6e9d510c9e | ||
|
|
d09ed8e8f0 | ||
|
|
8fe8b1e803 | ||
|
|
66438f3780 | ||
|
|
46f446fd0e | ||
|
|
f83d3a19d0 | ||
|
|
ef59ff1856 | ||
|
|
4966225282 | ||
|
|
2b8765d0aa | ||
|
|
d60d06f958 | ||
|
|
1a3f268476 | ||
|
|
515a855c40 | ||
|
|
16d65b8d12 | ||
|
|
bfe928df18 | ||
|
|
c447bbe6c8 | ||
|
|
1c0a3f95df | ||
|
|
8a6116ab79 | ||
|
|
430010fbea | ||
|
|
079b575a45 | ||
|
|
b2ca887d59 | ||
|
|
d7b30ad0d7 | ||
|
|
b084ace1dd | ||
|
|
b3e45cdf1a | ||
|
|
8132e1f7d9 | ||
|
|
149dccf244 | ||
|
|
b5e4797761 | ||
|
|
be670d6253 | ||
|
|
71060ea4e7 | ||
|
|
f60f38280c | ||
|
|
418deeb332 | ||
|
|
619c77c27e | ||
|
|
ddfddb49da | ||
|
|
dbbb1870b7 | ||
|
|
5b43301206 | ||
|
|
d915d1a94a | ||
|
|
786497790a | ||
|
|
56c899cf21 | ||
|
|
943f22e5a9 | ||
|
|
11b45689f4 | ||
|
|
87f443532f | ||
|
|
0c672a0c37 | ||
|
|
dfd11ceb57 | ||
|
|
d865b7fd87 | ||
|
|
aa8a6b9c43 | ||
|
|
fe5313f42e | ||
|
|
499f739e2b | ||
|
|
4e0e738823 | ||
|
|
24360bf306 | ||
|
|
6fad3c2bbd | ||
|
|
2cf20de7ec | ||
|
|
3d8d3bb8ce | ||
|
|
80bcbe4885 | ||
|
|
32e4782ed8 | ||
|
|
613a51bdbb | ||
|
|
1c6de43701 | ||
|
|
6771530025 | ||
|
|
5876f367bc | ||
|
|
e263af2dd9 | ||
|
|
3a59911a2b | ||
|
|
bbf31e99c3 | ||
|
|
9d5bd42f3e | ||
|
|
e721dae6da | ||
|
|
af3106b144 | ||
|
|
5b55103575 | ||
|
|
ee4ecf929f | ||
|
|
8336556a6f | ||
|
|
709aad1d3b | ||
|
|
fb7ab4937c | ||
|
|
5df1726d80 | ||
|
|
9fdb568843 | ||
|
|
8e76f56f89 | ||
|
|
05d3791577 | ||
|
|
d00dd7eb90 | ||
|
|
8d2e404017 | ||
|
|
95eb2af25e | ||
|
|
cbc00a501b | ||
|
|
480645d897 | ||
|
|
997c767c95 | ||
|
|
5a54e1dc9a | ||
|
|
49b1952566 | ||
|
|
e73edc2fce | ||
|
|
409652e874 | ||
|
|
1d3fb6431f | ||
|
|
76cfada60f | ||
|
|
ac45f80551 | ||
|
|
5ea85f086a | ||
|
|
e3f657746c | ||
|
|
001b56e2cc | ||
|
|
ecbfd2f0de | ||
|
|
45753397e1 | ||
|
|
dc6fe1dafe | ||
|
|
d5e8f2f416 | ||
|
|
d73af5a2b4 | ||
|
|
7042f2bba8 | ||
|
|
efeb260fa8 | ||
|
|
29e90092ea | ||
|
|
0abe865023 | ||
|
|
220c65a41a |
5
.github/actions/setup/action.yml
vendored
5
.github/actions/setup/action.yml
vendored
@@ -12,13 +12,14 @@ inputs:
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install apt deps
|
||||
- name: Install apt deps & cleanup
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get remove --purge man-db
|
||||
sudo apt-get update
|
||||
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext libkrb5-dev krb5-kdc krb5-user krb5-admin-server
|
||||
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext krb5-multidev libkrb5-dev heimdal-multidev libclang-dev krb5-kdc krb5-user krb5-admin-server
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
- name: Install uv
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v5
|
||||
|
||||
2
.github/actions/test-results/action.yml
vendored
2
.github/actions/test-results/action.yml
vendored
@@ -20,7 +20,7 @@ runs:
|
||||
- name: PostgreSQL Logs
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ $ACTIONS_RUNNER_DEBUG == 'true' || $ACTIONS_STEP_DEBUG == 'true' ]]; then
|
||||
if [[ $RUNNER_DEBUG == '1' ]]; then
|
||||
docker stop setup-postgresql-1
|
||||
echo "::group::PostgreSQL Logs"
|
||||
docker logs setup-postgresql-1
|
||||
|
||||
2
.github/workflows/ci-main.yml
vendored
2
.github/workflows/ci-main.yml
vendored
@@ -84,7 +84,7 @@ jobs:
|
||||
# Current version family based on
|
||||
current_version_family=$(cat internal/constants/VERSION | grep -vE -- 'rc[0-9]+$' || true)
|
||||
if [[ -n $current_version_family ]]; then
|
||||
prev_stable=$current_version_family
|
||||
prev_stable="version/${current_version_family}"
|
||||
fi
|
||||
echo "::notice::Checking out ${prev_stable} as stable version..."
|
||||
git checkout ${prev_stable}
|
||||
|
||||
5
.github/workflows/release-publish.yml
vendored
5
.github/workflows/release-publish.yml
vendored
@@ -101,7 +101,7 @@ jobs:
|
||||
- name: make empty clients
|
||||
run: |
|
||||
mkdir -p ./gen-ts-api
|
||||
mkdir -p ./gen-go-api
|
||||
make gen-client-go
|
||||
- name: Docker Login Registry
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3
|
||||
with:
|
||||
@@ -160,6 +160,9 @@ jobs:
|
||||
run: |
|
||||
npm ci
|
||||
npm run build-proxy
|
||||
- name: Build API client
|
||||
run: |
|
||||
make gen-client-go
|
||||
- name: Build outpost
|
||||
run: |
|
||||
set -x
|
||||
|
||||
4
.github/workflows/release-tag.yml
vendored
4
.github/workflows/release-tag.yml
vendored
@@ -49,8 +49,12 @@ jobs:
|
||||
test:
|
||||
name: Pre-release test
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- check-inputs
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v5
|
||||
with:
|
||||
ref: "version-${{ needs.check-inputs.outputs.major_version }}"
|
||||
- run: make test-docker
|
||||
bump-authentik:
|
||||
name: Bump authentik version
|
||||
|
||||
@@ -40,7 +40,7 @@ packages/tsconfig @goauthentik/frontend
|
||||
# Web
|
||||
web/ @goauthentik/frontend
|
||||
# Locale
|
||||
locale/ @goauthentik/backend @goauthentik/frontend
|
||||
/locale/ @goauthentik/backend @goauthentik/frontend
|
||||
web/xliff/ @goauthentik/backend @goauthentik/frontend
|
||||
# Docs
|
||||
website/ @goauthentik/docs
|
||||
|
||||
@@ -114,7 +114,7 @@ RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/v
|
||||
# postgresql
|
||||
libpq-dev \
|
||||
# python-kadmin-rs
|
||||
clang libkrb5-dev sccache \
|
||||
krb5-multidev libkrb5-dev heimdal-multidev libclang-dev \
|
||||
# xmlsec
|
||||
libltdl-dev && \
|
||||
curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
@@ -156,7 +156,11 @@ WORKDIR /
|
||||
RUN apt-get update && \
|
||||
apt-get upgrade -y && \
|
||||
# Required for runtime
|
||||
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates libkrb5-3 libkadm5clnt-mit12 libkdb5-10 libltdl7 libxslt1.1 && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libpq5 libmaxminddb0 ca-certificates \
|
||||
krb5-multidev libkrb5-3 libkdb5-10 libkadm5clnt-mit12 \
|
||||
heimdal-multidev libkadm5clnt7t64-heimdal \
|
||||
libltdl7 libxslt1.1 && \
|
||||
# Required for bootstrap & healtcheck
|
||||
apt-get install -y --no-install-recommends runit && \
|
||||
pip3 install --no-cache-dir --upgrade pip && \
|
||||
|
||||
15
Makefile
15
Makefile
@@ -188,24 +188,15 @@ gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescri
|
||||
|
||||
gen-client-py: gen-clean-py ## Build and install the authentik API for Python
|
||||
mkdir -p ${PWD}/${GEN_API_PY}
|
||||
ifeq ($(wildcard ${PWD}/${GEN_API_PY}/.*),)
|
||||
git clone --depth 1 https://github.com/goauthentik/client-python.git ${PWD}/${GEN_API_PY}
|
||||
else
|
||||
cd ${PWD}/${GEN_API_PY} && git pull
|
||||
endif
|
||||
cp ${PWD}/schema.yml ${PWD}/${GEN_API_PY}
|
||||
make -C ${PWD}/${GEN_API_PY} build version=${NPM_VERSION}
|
||||
|
||||
gen-client-go: ## Build and install the authentik API for Golang
|
||||
gen-client-go: gen-clean-go ## Build and install the authentik API for Golang
|
||||
mkdir -p ${PWD}/${GEN_API_GO}
|
||||
ifeq ($(wildcard ${PWD}/${GEN_API_GO}/.*),)
|
||||
git clone --depth 1 https://github.com/goauthentik/client-go.git ${PWD}/${GEN_API_GO}
|
||||
else
|
||||
cd ${PWD}/${GEN_API_GO} && git reset --hard
|
||||
cd ${PWD}/${GEN_API_GO} && git pull
|
||||
endif
|
||||
cp ${PWD}/schema.yml ${PWD}/${GEN_API_GO}
|
||||
make -C ${PWD}/${GEN_API_GO} build
|
||||
make -C ${PWD}/${GEN_API_GO} build version=${NPM_VERSION}
|
||||
go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO}
|
||||
|
||||
gen-dev-config: ## Generate a local development config file
|
||||
@@ -327,6 +318,6 @@ ci-pending-migrations: ci--meta-debug
|
||||
uv run ak makemigrations --check
|
||||
|
||||
ci-test: ci--meta-debug
|
||||
uv run coverage run manage.py test --keepdb --randomly-seed ${CI_TEST_SEED} authentik
|
||||
uv run coverage run manage.py test --keepdb authentik
|
||||
uv run coverage report
|
||||
uv run coverage xml
|
||||
|
||||
@@ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
| Version | Supported |
|
||||
| ---------- | ---------- |
|
||||
| 2025.8.x | ✅ |
|
||||
| 2025.10.x | ✅ |
|
||||
| 2025.12.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
|
||||
VERSION = "2025.12.0-rc1"
|
||||
VERSION = "2025.12.2"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class VersionSerializer(PassiveSerializer):
|
||||
|
||||
def get_version_latest(self, _) -> str:
|
||||
"""Get latest version from cache"""
|
||||
if get_current_tenant().schema_name == get_public_schema_name():
|
||||
if get_current_tenant().schema_name != get_public_schema_name():
|
||||
return authentik_version()
|
||||
version_in_cache = cache.get(VERSION_CACHE_KEY)
|
||||
if not version_in_cache: # pragma: no cover
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import mimetypes
|
||||
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.utils import extend_schema
|
||||
@@ -12,13 +10,14 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from authentik.admin.files.backends.base import get_content_type
|
||||
from authentik.admin.files.fields import FileField as AkFileField
|
||||
from authentik.admin.files.manager import get_file_manager
|
||||
from authentik.admin.files.usage import FileApiUsage
|
||||
from authentik.admin.files.validation import validate_upload_file_name
|
||||
from authentik.api.validation import validate
|
||||
from authentik.core.api.used_by import DeleteAction, UsedBySerializer
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer, ThemedUrlsSerializer
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
@@ -26,11 +25,6 @@ from authentik.rbac.permissions import HasPermission
|
||||
MAX_FILE_SIZE_BYTES = 25 * 1024 * 1024 # 25MB
|
||||
|
||||
|
||||
def get_mime_from_filename(filename: str) -> str:
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
return mime_type or "application/octet-stream"
|
||||
|
||||
|
||||
class FileView(APIView):
|
||||
pagination_class = None
|
||||
parser_classes = [MultiPartParser]
|
||||
@@ -53,6 +47,7 @@ class FileView(APIView):
|
||||
name = CharField()
|
||||
mime_type = CharField()
|
||||
url = CharField()
|
||||
themed_urls = ThemedUrlsSerializer(required=False, allow_null=True)
|
||||
|
||||
@extend_schema(
|
||||
parameters=[FileListParameters],
|
||||
@@ -80,8 +75,9 @@ class FileView(APIView):
|
||||
FileView.FileListSerializer(
|
||||
data={
|
||||
"name": file,
|
||||
"url": manager.file_url(file),
|
||||
"mime_type": get_mime_from_filename(file),
|
||||
"url": manager.file_url(file, request),
|
||||
"mime_type": get_content_type(file),
|
||||
"themed_urls": manager.themed_urls(file, request),
|
||||
}
|
||||
)
|
||||
for file in files
|
||||
@@ -150,7 +146,7 @@ class FileView(APIView):
|
||||
"pk": name,
|
||||
"name": name,
|
||||
"usage": usage.value,
|
||||
"mime_type": get_mime_from_filename(name),
|
||||
"mime_type": get_content_type(name),
|
||||
},
|
||||
).from_http(request)
|
||||
|
||||
@@ -240,7 +236,9 @@ class FileUsedByView(APIView):
|
||||
for field in fields:
|
||||
q |= Q(**{field: params.get("name")})
|
||||
|
||||
objs = get_objects_for_user(request.user, f"{app}.view_{model_name}", model)
|
||||
objs = get_objects_for_user(
|
||||
request.user, f"{app}.view_{model_name}", model.objects.all()
|
||||
)
|
||||
objs = objs.filter(q)
|
||||
for obj in objs:
|
||||
serializer = UsedBySerializer(
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
class AuthentikFilesConfig(ManagedAppConfig):
|
||||
@@ -11,20 +6,3 @@ class AuthentikFilesConfig(ManagedAppConfig):
|
||||
label = "authentik_admin_files"
|
||||
verbose_name = "authentik Files"
|
||||
default = True
|
||||
|
||||
@ManagedAppConfig.reconcile_global
|
||||
def check_for_media_mount(self):
|
||||
if settings.TEST:
|
||||
return
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
|
||||
if (
|
||||
CONFIG.get("storage.media.backend", CONFIG.get("storage.backend", "file")) == "file"
|
||||
and Path("/media").exists()
|
||||
):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="/media has been moved to /data/media. "
|
||||
"Check the release notes for migration steps.",
|
||||
).save()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import mimetypes
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
@@ -10,6 +11,32 @@ from authentik.admin.files.usage import FileUsage
|
||||
CACHE_PREFIX = "goauthentik.io/admin/files"
|
||||
LOGGER = get_logger()
|
||||
|
||||
# Theme variable placeholder for theme-specific files like logo-%(theme)s.png
|
||||
THEME_VARIABLE = "%(theme)s"
|
||||
|
||||
|
||||
def get_content_type(name: str) -> str:
|
||||
"""Get MIME type for a file based on its extension."""
|
||||
content_type, _ = mimetypes.guess_type(name)
|
||||
return content_type or "application/octet-stream"
|
||||
|
||||
|
||||
def get_valid_themes() -> list[str]:
|
||||
"""Get valid themes that can be substituted for %(theme)s."""
|
||||
from authentik.brands.api import Themes
|
||||
|
||||
return [t.value for t in Themes if t != Themes.AUTOMATIC]
|
||||
|
||||
|
||||
def has_theme_variable(name: str) -> bool:
|
||||
"""Check if filename contains %(theme)s variable."""
|
||||
return THEME_VARIABLE in name
|
||||
|
||||
|
||||
def substitute_theme(name: str, theme: str) -> str:
|
||||
"""Replace %(theme)s with the given theme."""
|
||||
return name.replace(THEME_VARIABLE, theme)
|
||||
|
||||
|
||||
class Backend:
|
||||
"""
|
||||
@@ -75,6 +102,29 @@ class Backend:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def themed_urls(
|
||||
self,
|
||||
name: str,
|
||||
request: HttpRequest | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
Get URLs for each theme variant when filename contains %(theme)s.
|
||||
|
||||
Args:
|
||||
name: File path potentially containing %(theme)s
|
||||
request: Optional Django HttpRequest for URL building
|
||||
|
||||
Returns:
|
||||
Dict mapping theme to URL if %(theme)s present, None otherwise
|
||||
"""
|
||||
if not has_theme_variable(name):
|
||||
return None
|
||||
|
||||
return {
|
||||
theme: self.file_url(substitute_theme(name, theme), request, use_cache=True)
|
||||
for theme in get_valid_themes()
|
||||
}
|
||||
|
||||
|
||||
class ManageableBackend(Backend):
|
||||
"""
|
||||
|
||||
@@ -45,8 +45,13 @@ class FileBackend(ManageableBackend):
|
||||
|
||||
@property
|
||||
def manageable(self) -> bool:
|
||||
# Check _base_dir (the mount point, e.g. /data) rather than base_path
|
||||
# (which includes usage/schema subdirs, e.g. /data/media/public).
|
||||
# The subdirectories are created on first file write via mkdir(parents=True)
|
||||
# in save_file(), so requiring them to exist beforehand would prevent
|
||||
# file creation on fresh installs.
|
||||
return (
|
||||
self.base_path.exists()
|
||||
self._base_dir.exists()
|
||||
and (self._base_dir.is_mount() or (self._base_dir / self.usage.value).is_mount())
|
||||
or (settings.DEBUG or settings.TEST)
|
||||
)
|
||||
|
||||
@@ -46,3 +46,25 @@ class PassthroughBackend(Backend):
|
||||
) -> str:
|
||||
"""Return the URL as-is for passthrough files."""
|
||||
return name
|
||||
|
||||
def themed_urls(
|
||||
self,
|
||||
name: str,
|
||||
request: HttpRequest | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""Support themed URLs for external URLs with %(theme)s placeholder.
|
||||
|
||||
If the external URL contains %(theme)s, substitute it for each theme.
|
||||
We can't verify that themed variants exist at the external location,
|
||||
but we trust the user to provide valid URLs.
|
||||
"""
|
||||
from authentik.admin.files.backends.base import (
|
||||
get_valid_themes,
|
||||
has_theme_variable,
|
||||
substitute_theme,
|
||||
)
|
||||
|
||||
if not has_theme_variable(name):
|
||||
return None
|
||||
|
||||
return {theme: substitute_theme(name, theme) for theme in get_valid_themes()}
|
||||
|
||||
@@ -9,7 +9,7 @@ from botocore.exceptions import ClientError
|
||||
from django.db import connection
|
||||
from django.http.request import HttpRequest
|
||||
|
||||
from authentik.admin.files.backends.base import ManageableBackend
|
||||
from authentik.admin.files.backends.base import ManageableBackend, get_content_type
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
@@ -173,7 +173,22 @@ class S3Backend(ManageableBackend):
|
||||
if custom_domain:
|
||||
parsed = urlsplit(url)
|
||||
scheme = "https" if use_https else "http"
|
||||
url = f"{scheme}://{custom_domain}{parsed.path}?{parsed.query}"
|
||||
path = parsed.path
|
||||
|
||||
# When using path-style addressing, the presigned URL contains the bucket
|
||||
# name in the path (e.g., /bucket-name/key). Since custom_domain must
|
||||
# include the bucket name (per docs), strip it from the path to avoid
|
||||
# duplication. See: https://github.com/goauthentik/authentik/issues/19521
|
||||
# Check with trailing slash to ensure exact bucket name match
|
||||
if path.startswith(f"/{self.bucket_name}/"):
|
||||
path = path.removeprefix(f"/{self.bucket_name}")
|
||||
|
||||
# Normalize to avoid double slashes
|
||||
custom_domain = custom_domain.rstrip("/")
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
url = f"{scheme}://{custom_domain}{path}?{parsed.query}"
|
||||
|
||||
return url
|
||||
|
||||
@@ -189,6 +204,7 @@ class S3Backend(ManageableBackend):
|
||||
Key=f"{self.base_path}/{name}",
|
||||
Body=content,
|
||||
ACL="private",
|
||||
ContentType=get_content_type(name),
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
@@ -204,6 +220,7 @@ class S3Backend(ManageableBackend):
|
||||
Key=f"{self.base_path}/{name}",
|
||||
ExtraArgs={
|
||||
"ACL": "private",
|
||||
"ContentType": get_content_type(name),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -165,3 +165,31 @@ class TestFileBackend(FileTestFileBackendMixin, TestCase):
|
||||
def test_file_exists_false(self):
|
||||
"""Test file_exists returns False for nonexistent file"""
|
||||
self.assertFalse(self.backend.file_exists("does_not_exist.txt"))
|
||||
|
||||
def test_themed_urls_without_theme_variable(self):
|
||||
"""Test themed_urls returns None when filename has no %(theme)s"""
|
||||
file_name = "logo.png"
|
||||
result = self.backend.themed_urls(file_name)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_themed_urls_with_theme_variable(self):
|
||||
"""Test themed_urls returns dict of URLs for each theme"""
|
||||
file_name = "logo-%(theme)s.png"
|
||||
result = self.backend.themed_urls(file_name)
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn("light", result)
|
||||
self.assertIn("dark", result)
|
||||
|
||||
# Check URLs contain the substituted theme
|
||||
self.assertIn("logo-light.png", result["light"])
|
||||
self.assertIn("logo-dark.png", result["dark"])
|
||||
|
||||
def test_themed_urls_multiple_theme_variables(self):
|
||||
"""Test themed_urls with multiple %(theme)s in path"""
|
||||
file_name = "%(theme)s/logo-%(theme)s.svg"
|
||||
result = self.backend.themed_urls(file_name)
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn("light/logo-light.svg", result["light"])
|
||||
self.assertIn("dark/logo-dark.svg", result["dark"])
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from unittest import skipUnless
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.admin.files.tests.utils import FileTestS3BackendMixin
|
||||
from authentik.admin.files.tests.utils import FileTestS3BackendMixin, s3_test_server_available
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
@skipUnless(s3_test_server_available(), "S3 test server not available")
|
||||
class TestS3Backend(FileTestS3BackendMixin, TestCase):
|
||||
"""Test S3 backend functionality"""
|
||||
|
||||
@@ -107,3 +110,106 @@ class TestS3Backend(FileTestS3BackendMixin, TestCase):
|
||||
"""Test S3Backend with REPORTS usage"""
|
||||
self.assertEqual(self.reports_s3_backend.usage, FileUsage.REPORTS)
|
||||
self.assertEqual(self.reports_s3_backend.base_path, "reports/public")
|
||||
|
||||
@CONFIG.patch("storage.s3.secure_urls", True)
|
||||
@CONFIG.patch("storage.s3.addressing_style", "path")
|
||||
def test_file_url_custom_domain_with_bucket_no_duplicate(self):
|
||||
"""Test file_url doesn't duplicate bucket name when custom_domain includes bucket.
|
||||
|
||||
Regression test for https://github.com/goauthentik/authentik/issues/19521
|
||||
|
||||
When using:
|
||||
- Path-style addressing (bucket name goes in URL path, not subdomain)
|
||||
- Custom domain that includes the bucket name (e.g., s3.example.com/bucket-name)
|
||||
|
||||
The bucket name should NOT appear twice in the final URL.
|
||||
|
||||
Example of the bug:
|
||||
- custom_domain = "s3.example.com/authentik-media"
|
||||
- boto3 presigned URL = "http://s3.example.com/authentik-media/media/public/file.png?..."
|
||||
- Buggy result = "https://s3.example.com/authentik-media/authentik-media/media/public/file.png?..."
|
||||
"""
|
||||
bucket_name = self.media_s3_bucket_name
|
||||
|
||||
# Custom domain includes the bucket name
|
||||
custom_domain = f"localhost:8020/{bucket_name}"
|
||||
|
||||
with CONFIG.patch("storage.media.s3.custom_domain", custom_domain):
|
||||
url = self.media_s3_backend.file_url("application-icons/test.svg", use_cache=False)
|
||||
|
||||
# The bucket name should appear exactly once in the URL path, not twice
|
||||
bucket_occurrences = url.count(bucket_name)
|
||||
self.assertEqual(
|
||||
bucket_occurrences,
|
||||
1,
|
||||
f"Bucket name '{bucket_name}' appears {bucket_occurrences} times in URL, expected 1. "
|
||||
f"URL: {url}",
|
||||
)
|
||||
|
||||
def test_themed_urls_without_theme_variable(self):
|
||||
"""Test themed_urls returns None when filename has no %(theme)s"""
|
||||
result = self.media_s3_backend.themed_urls("logo.png")
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_themed_urls_with_theme_variable(self):
|
||||
"""Test themed_urls returns dict of presigned URLs for each theme"""
|
||||
result = self.media_s3_backend.themed_urls("logo-%(theme)s.png")
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn("light", result)
|
||||
self.assertIn("dark", result)
|
||||
|
||||
# Check URLs are valid presigned URLs with correct file paths
|
||||
self.assertIn("logo-light.png", result["light"])
|
||||
self.assertIn("logo-dark.png", result["dark"])
|
||||
self.assertIn("X-Amz-Signature=", result["light"])
|
||||
self.assertIn("X-Amz-Signature=", result["dark"])
|
||||
|
||||
def test_themed_urls_multiple_theme_variables(self):
|
||||
"""Test themed_urls with multiple %(theme)s in path"""
|
||||
result = self.media_s3_backend.themed_urls("%(theme)s/logo-%(theme)s.svg")
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn("light/logo-light.svg", result["light"])
|
||||
self.assertIn("dark/logo-dark.svg", result["dark"])
|
||||
|
||||
def test_save_file_sets_content_type_svg(self):
|
||||
"""Test save_file sets correct ContentType for SVG files"""
|
||||
self.media_s3_backend.save_file("test.svg", b"<svg></svg>")
|
||||
|
||||
response = self.media_s3_backend.client.head_object(
|
||||
Bucket=self.media_s3_bucket_name,
|
||||
Key="media/public/test.svg",
|
||||
)
|
||||
self.assertEqual(response["ContentType"], "image/svg+xml")
|
||||
|
||||
def test_save_file_sets_content_type_png(self):
|
||||
"""Test save_file sets correct ContentType for PNG files"""
|
||||
self.media_s3_backend.save_file("test.png", b"\x89PNG\r\n\x1a\n")
|
||||
|
||||
response = self.media_s3_backend.client.head_object(
|
||||
Bucket=self.media_s3_bucket_name,
|
||||
Key="media/public/test.png",
|
||||
)
|
||||
self.assertEqual(response["ContentType"], "image/png")
|
||||
|
||||
def test_save_file_stream_sets_content_type(self):
|
||||
"""Test save_file_stream sets correct ContentType"""
|
||||
with self.media_s3_backend.save_file_stream("test.css") as f:
|
||||
f.write(b"body { color: red; }")
|
||||
|
||||
response = self.media_s3_backend.client.head_object(
|
||||
Bucket=self.media_s3_bucket_name,
|
||||
Key="media/public/test.css",
|
||||
)
|
||||
self.assertEqual(response["ContentType"], "text/css")
|
||||
|
||||
def test_save_file_unknown_extension_octet_stream(self):
|
||||
"""Test save_file sets octet-stream for unknown extensions"""
|
||||
self.media_s3_backend.save_file("test.unknownext123", b"data")
|
||||
|
||||
response = self.media_s3_backend.client.head_object(
|
||||
Bucket=self.media_s3_bucket_name,
|
||||
Key="media/public/test.unknownext123",
|
||||
)
|
||||
self.assertEqual(response["ContentType"], "application/octet-stream")
|
||||
|
||||
@@ -88,6 +88,28 @@ class FileManager:
|
||||
LOGGER.warning(f"Could not find file backend for file: {name}")
|
||||
return ""
|
||||
|
||||
def themed_urls(
|
||||
self,
|
||||
name: str | None,
|
||||
request: HttpRequest | Request | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
Get URLs for each theme variant when filename contains %(theme)s.
|
||||
|
||||
Returns dict mapping theme to URL if %(theme)s present, None otherwise.
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
|
||||
if isinstance(request, Request):
|
||||
request = request._request
|
||||
|
||||
for backend in self.backends:
|
||||
if backend.supports_file(name):
|
||||
return backend.themed_urls(name, request)
|
||||
|
||||
return None
|
||||
|
||||
def _check_manageable(self) -> None:
|
||||
if not self.manageable:
|
||||
raise ImproperlyConfigured("No file management backend configured.")
|
||||
|
||||
@@ -5,7 +5,6 @@ from io import BytesIO
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.admin.files.api import get_mime_from_filename
|
||||
from authentik.admin.files.manager import FileManager
|
||||
from authentik.admin.files.tests.utils import FileTestFileBackendMixin
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
@@ -94,8 +93,9 @@ class TestFileAPI(FileTestFileBackendMixin, TestCase):
|
||||
self.assertIn(
|
||||
{
|
||||
"name": "/static/authentik/sources/ldap.png",
|
||||
"url": "/static/authentik/sources/ldap.png",
|
||||
"url": "http://testserver/static/authentik/sources/ldap.png",
|
||||
"mime_type": "image/png",
|
||||
"themed_urls": None,
|
||||
},
|
||||
response.data,
|
||||
)
|
||||
@@ -129,8 +129,9 @@ class TestFileAPI(FileTestFileBackendMixin, TestCase):
|
||||
self.assertIn(
|
||||
{
|
||||
"name": "/static/authentik/sources/ldap.png",
|
||||
"url": "/static/authentik/sources/ldap.png",
|
||||
"url": "http://testserver/static/authentik/sources/ldap.png",
|
||||
"mime_type": "image/png",
|
||||
"themed_urls": None,
|
||||
},
|
||||
response.data,
|
||||
)
|
||||
@@ -200,30 +201,64 @@ class TestFileAPI(FileTestFileBackendMixin, TestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertIn("field is required", str(response.data))
|
||||
|
||||
def test_list_files_includes_themed_urls_none(self):
|
||||
"""Test listing files includes themed_urls as None for non-themed files"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
file_name = "test-no-theme.png"
|
||||
manager.save_file(file_name, b"test content")
|
||||
|
||||
class TestGetMimeFromFilename(TestCase):
|
||||
"""Test get_mime_from_filename function"""
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:files", query={"search": file_name, "manageableOnly": "true"})
|
||||
)
|
||||
|
||||
def test_image_png(self):
|
||||
"""Test PNG image MIME type"""
|
||||
self.assertEqual(get_mime_from_filename("test.png"), "image/png")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
file_entry = next((f for f in response.data if f["name"] == file_name), None)
|
||||
self.assertIsNotNone(file_entry)
|
||||
self.assertIn("themed_urls", file_entry)
|
||||
self.assertIsNone(file_entry["themed_urls"])
|
||||
|
||||
def test_image_jpeg(self):
|
||||
"""Test JPEG image MIME type"""
|
||||
self.assertEqual(get_mime_from_filename("test.jpg"), "image/jpeg")
|
||||
manager.delete_file(file_name)
|
||||
|
||||
def test_image_svg(self):
|
||||
"""Test SVG image MIME type"""
|
||||
self.assertEqual(get_mime_from_filename("test.svg"), "image/svg+xml")
|
||||
def test_list_files_includes_themed_urls_dict(self):
|
||||
"""Test listing files includes themed_urls as dict for themed files"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
file_name = "logo-%(theme)s.svg"
|
||||
manager.save_file("logo-light.svg", b"<svg>light</svg>")
|
||||
manager.save_file("logo-dark.svg", b"<svg>dark</svg>")
|
||||
manager.save_file(file_name, b"<svg>placeholder</svg>")
|
||||
|
||||
def test_text_plain(self):
|
||||
"""Test text file MIME type"""
|
||||
self.assertEqual(get_mime_from_filename("test.txt"), "text/plain")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:files", query={"search": "%(theme)s", "manageableOnly": "true"})
|
||||
)
|
||||
|
||||
def test_unknown_extension(self):
|
||||
"""Test unknown extension returns octet-stream"""
|
||||
self.assertEqual(get_mime_from_filename("test.unknown"), "application/octet-stream")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
file_entry = next((f for f in response.data if f["name"] == file_name), None)
|
||||
self.assertIsNotNone(file_entry)
|
||||
self.assertIn("themed_urls", file_entry)
|
||||
self.assertIsInstance(file_entry["themed_urls"], dict)
|
||||
self.assertIn("light", file_entry["themed_urls"])
|
||||
self.assertIn("dark", file_entry["themed_urls"])
|
||||
|
||||
def test_no_extension(self):
|
||||
"""Test no extension returns octet-stream"""
|
||||
self.assertEqual(get_mime_from_filename("test"), "application/octet-stream")
|
||||
manager.delete_file(file_name)
|
||||
manager.delete_file("logo-light.svg")
|
||||
manager.delete_file("logo-dark.svg")
|
||||
|
||||
def test_upload_file_with_theme_variable(self):
|
||||
"""Test uploading file with %(theme)s in name"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
file_name = "brand-logo-%(theme)s.svg"
|
||||
file_content = b"<svg></svg>"
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:files"),
|
||||
{
|
||||
"file": BytesIO(file_content),
|
||||
"name": file_name,
|
||||
"usage": FileUsage.MEDIA.value,
|
||||
},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(manager.file_exists(file_name))
|
||||
manager.delete_file(file_name)
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Test file service layer"""
|
||||
|
||||
from unittest import skipUnless
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.admin.files.manager import FileManager
|
||||
from authentik.admin.files.tests.utils import FileTestFileBackendMixin, FileTestS3BackendMixin
|
||||
from authentik.admin.files.tests.utils import (
|
||||
FileTestFileBackendMixin,
|
||||
FileTestS3BackendMixin,
|
||||
s3_test_server_available,
|
||||
)
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
@@ -81,6 +88,7 @@ class TestResolveFileUrlFileBackend(FileTestFileBackendMixin, TestCase):
|
||||
self.assertEqual(result, "http://example.com/files/media/public/test.png")
|
||||
|
||||
|
||||
@skipUnless(s3_test_server_available(), "S3 test server not available")
|
||||
class TestResolveFileUrlS3Backend(FileTestS3BackendMixin, TestCase):
|
||||
@CONFIG.patch("storage.media.s3.custom_domain", "s3.test:8080/test")
|
||||
@CONFIG.patch("storage.media.s3.secure_urls", False)
|
||||
@@ -97,3 +105,71 @@ class TestResolveFileUrlS3Backend(FileTestS3BackendMixin, TestCase):
|
||||
|
||||
# S3 URLs should be returned as-is (already absolute)
|
||||
self.assertTrue(result.startswith("http://s3.test:8080/test"))
|
||||
|
||||
|
||||
class TestThemedUrls(FileTestFileBackendMixin, TestCase):
|
||||
"""Test FileManager.themed_urls method"""
|
||||
|
||||
def test_themed_urls_none_path(self):
|
||||
"""Test themed_urls returns None for None path"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
result = manager.themed_urls(None)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_themed_urls_empty_path(self):
|
||||
"""Test themed_urls returns None for empty path"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
result = manager.themed_urls("")
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_themed_urls_no_theme_variable(self):
|
||||
"""Test themed_urls returns None when no %(theme)s in path"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
result = manager.themed_urls("logo.png")
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_themed_urls_with_theme_variable(self):
|
||||
"""Test themed_urls returns dict of URLs for each theme"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
result = manager.themed_urls("logo-%(theme)s.png")
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn("light", result)
|
||||
self.assertIn("dark", result)
|
||||
self.assertIn("logo-light.png", result["light"])
|
||||
self.assertIn("logo-dark.png", result["dark"])
|
||||
|
||||
def test_themed_urls_with_request(self):
|
||||
"""Test themed_urls builds absolute URLs with request"""
|
||||
mock_request = HttpRequest()
|
||||
mock_request.META = {
|
||||
"HTTP_HOST": "example.com",
|
||||
"SERVER_NAME": "example.com",
|
||||
}
|
||||
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
result = manager.themed_urls("logo-%(theme)s.svg", mock_request)
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
light_url = urlparse(result["light"])
|
||||
dark_url = urlparse(result["dark"])
|
||||
self.assertEqual(light_url.scheme, "http")
|
||||
self.assertEqual(light_url.netloc, "example.com")
|
||||
self.assertEqual(dark_url.scheme, "http")
|
||||
self.assertEqual(dark_url.netloc, "example.com")
|
||||
|
||||
def test_themed_urls_passthrough_with_theme_variable(self):
|
||||
"""Test themed_urls returns dict for passthrough URLs with %(theme)s"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
# External URLs with %(theme)s should return themed URLs
|
||||
result = manager.themed_urls("https://example.com/logo-%(theme)s.png")
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertEqual(result["light"], "https://example.com/logo-light.png")
|
||||
self.assertEqual(result["dark"], "https://example.com/logo-dark.png")
|
||||
|
||||
def test_themed_urls_passthrough_without_theme_variable(self):
|
||||
"""Test themed_urls returns None for passthrough URLs without %(theme)s"""
|
||||
manager = FileManager(FileUsage.MEDIA)
|
||||
# External URLs without %(theme)s should return None
|
||||
result = manager.themed_urls("https://example.com/logo.png")
|
||||
self.assertIsNone(result)
|
||||
|
||||
@@ -62,10 +62,10 @@ class TestSanitizeFilePath(TestCase):
|
||||
"test@file.png", # @
|
||||
"test#file.png", # #
|
||||
"test$file.png", # $
|
||||
"test%file.png", # %
|
||||
"test%file.png", # % (but %(theme)s is allowed)
|
||||
"test&file.png", # &
|
||||
"test*file.png", # *
|
||||
"test(file).png", # parentheses
|
||||
"test(file).png", # parentheses (but %(theme)s is allowed)
|
||||
"test[file].png", # brackets
|
||||
"test{file}.png", # braces
|
||||
]
|
||||
@@ -108,3 +108,30 @@ class TestSanitizeFilePath(TestCase):
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
validate_file_name(path)
|
||||
|
||||
def test_sanitize_theme_variable_valid(self):
|
||||
"""Test sanitizing filename with %(theme)s variable"""
|
||||
# These should all be valid
|
||||
validate_file_name("logo-%(theme)s.png")
|
||||
validate_file_name("brand/logo-%(theme)s.svg")
|
||||
validate_file_name("images/icon-%(theme)s.png")
|
||||
validate_file_name("%(theme)s/logo.png")
|
||||
validate_file_name("brand/%(theme)s/logo.png")
|
||||
|
||||
def test_sanitize_theme_variable_multiple(self):
|
||||
"""Test sanitizing filename with multiple %(theme)s variables"""
|
||||
validate_file_name("%(theme)s/logo-%(theme)s.png")
|
||||
|
||||
def test_sanitize_theme_variable_invalid_format(self):
|
||||
"""Test that partial or malformed theme variables are rejected"""
|
||||
invalid_paths = [
|
||||
"test%(theme.png", # missing )s
|
||||
"test%theme)s.png", # missing (
|
||||
"test%(themes).png", # wrong variable name
|
||||
"test%(THEME)s.png", # wrong case
|
||||
"test%()s.png", # empty variable name
|
||||
]
|
||||
|
||||
for path in invalid_paths:
|
||||
with self.assertRaises(ValidationError):
|
||||
validate_file_name(path)
|
||||
|
||||
@@ -1,11 +1,26 @@
|
||||
import shutil
|
||||
import socket
|
||||
from tempfile import mkdtemp
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from authentik.admin.files.backends.s3 import S3Backend
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
from authentik.lib.config import CONFIG, UNSET
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
S3_TEST_ENDPOINT = "http://localhost:8020"
|
||||
|
||||
|
||||
def s3_test_server_available() -> bool:
|
||||
"""Check if the S3 test server is reachable."""
|
||||
|
||||
parsed = urlparse(S3_TEST_ENDPOINT)
|
||||
try:
|
||||
with socket.create_connection((parsed.hostname, parsed.port), timeout=2):
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
class FileTestFileBackendMixin:
|
||||
def setUp(self):
|
||||
@@ -57,7 +72,7 @@ class FileTestS3BackendMixin:
|
||||
for key in s3_config_keys:
|
||||
self.original_media_s3_settings[key] = CONFIG.get(f"storage.media.s3.{key}", UNSET)
|
||||
self.media_s3_bucket_name = f"authentik-test-{generate_id(10)}".lower()
|
||||
CONFIG.set("storage.media.s3.endpoint", "http://localhost:8020")
|
||||
CONFIG.set("storage.media.s3.endpoint", S3_TEST_ENDPOINT)
|
||||
CONFIG.set("storage.media.s3.access_key", "accessKey1")
|
||||
CONFIG.set("storage.media.s3.secret_key", "secretKey1")
|
||||
CONFIG.set("storage.media.s3.bucket_name", self.media_s3_bucket_name)
|
||||
@@ -70,7 +85,7 @@ class FileTestS3BackendMixin:
|
||||
for key in s3_config_keys:
|
||||
self.original_reports_s3_settings[key] = CONFIG.get(f"storage.reports.s3.{key}", UNSET)
|
||||
self.reports_s3_bucket_name = f"authentik-test-{generate_id(10)}".lower()
|
||||
CONFIG.set("storage.reports.s3.endpoint", "http://localhost:8020")
|
||||
CONFIG.set("storage.reports.s3.endpoint", S3_TEST_ENDPOINT)
|
||||
CONFIG.set("storage.reports.s3.access_key", "accessKey1")
|
||||
CONFIG.set("storage.reports.s3.secret_key", "secretKey1")
|
||||
CONFIG.set("storage.reports.s3.bucket_name", self.reports_s3_bucket_name)
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import PurePosixPath
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from authentik.admin.files.backends.base import THEME_VARIABLE
|
||||
from authentik.admin.files.backends.passthrough import PassthroughBackend
|
||||
from authentik.admin.files.backends.static import StaticBackend
|
||||
from authentik.admin.files.usage import FileUsage
|
||||
@@ -39,12 +40,17 @@ def validate_upload_file_name(
|
||||
if not name:
|
||||
raise ValidationError(_("File name cannot be empty"))
|
||||
|
||||
# Same regex is used in the frontend as well
|
||||
if not re.match(r"^[a-zA-Z0-9._/-]+$", name):
|
||||
# Allow %(theme)s placeholder for theme-specific files
|
||||
# Replace with placeholder for validation, then check the result
|
||||
name_for_validation = name.replace(THEME_VARIABLE, "theme")
|
||||
|
||||
# Same regex is used in the frontend as well (with %(theme)s handling)
|
||||
if not re.match(r"^[a-zA-Z0-9._/-]+$", name_for_validation):
|
||||
raise ValidationError(
|
||||
_(
|
||||
"File name can only contain letters (a-z, A-Z), numbers (0-9), "
|
||||
"dots (.), hyphens (-), underscores (_), and forward slashes (/)"
|
||||
"dots (.), hyphens (-), underscores (_), forward slashes (/), "
|
||||
"and the placeholder %(theme)s for theme-specific files"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,9 @@ class Pagination(pagination.PageNumberPagination):
|
||||
|
||||
def get_page_size(self, request):
|
||||
if self.page_size_query_param in request.query_params:
|
||||
return min(super().get_page_size(request), request.tenant.pagination_max_page_size)
|
||||
page_size = super().get_page_size(request)
|
||||
if page_size is not None:
|
||||
return min(super().get_page_size(request), request.tenant.pagination_max_page_size)
|
||||
return request.tenant.pagination_default_page_size
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
|
||||
@@ -31,6 +31,7 @@ class Capabilities(models.TextChoices):
|
||||
"""Define capabilities which influence which APIs can/should be used"""
|
||||
|
||||
CAN_SAVE_MEDIA = "can_save_media"
|
||||
CAN_SAVE_REPORTS = "can_save_reports"
|
||||
CAN_GEO_IP = "can_geo_ip"
|
||||
CAN_ASN = "can_asn"
|
||||
CAN_IMPERSONATE = "can_impersonate"
|
||||
@@ -70,6 +71,8 @@ class ConfigView(APIView):
|
||||
caps = []
|
||||
if get_file_manager(FileUsage.MEDIA).manageable:
|
||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||
if get_file_manager(FileUsage.REPORTS).manageable:
|
||||
caps.append(Capabilities.CAN_SAVE_REPORTS)
|
||||
for processor in get_context_processors():
|
||||
if cap := processor.capability():
|
||||
caps.append(cap)
|
||||
|
||||
@@ -8,45 +8,62 @@ metadata:
|
||||
- Application (icon)
|
||||
- Source (icon)
|
||||
- Flow (background)
|
||||
- Endpoint Enrollment token (key)
|
||||
entries:
|
||||
- model: authentik_core.token
|
||||
identifiers:
|
||||
identifier: "%(uid)s-token"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
user: "%(user)s"
|
||||
intent: api
|
||||
- model: authentik_core.application
|
||||
identifiers:
|
||||
slug: "%(uid)s-app"
|
||||
attrs:
|
||||
name: "%(uid)s-app"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_sources_oauth.oauthsource
|
||||
identifiers:
|
||||
slug: "%(uid)s-source"
|
||||
attrs:
|
||||
name: "%(uid)s-source"
|
||||
provider_type: azuread
|
||||
consumer_key: "%(uid)s"
|
||||
consumer_secret: "%(uid)s"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_flows.flow
|
||||
identifiers:
|
||||
slug: "%(uid)s-flow"
|
||||
attrs:
|
||||
name: "%(uid)s-flow"
|
||||
title: "%(uid)s-flow"
|
||||
designation: authentication
|
||||
background: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
password: "%(uid)s"
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s-no-password"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
token:
|
||||
- model: authentik_core.token
|
||||
identifiers:
|
||||
identifier: "%(uid)s-token"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
user: "%(user)s"
|
||||
intent: api
|
||||
app:
|
||||
- model: authentik_core.application
|
||||
identifiers:
|
||||
slug: "%(uid)s-app"
|
||||
attrs:
|
||||
name: "%(uid)s-app"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
source:
|
||||
- model: authentik_sources_oauth.oauthsource
|
||||
identifiers:
|
||||
slug: "%(uid)s-source"
|
||||
attrs:
|
||||
name: "%(uid)s-source"
|
||||
provider_type: azuread
|
||||
consumer_key: "%(uid)s"
|
||||
consumer_secret: "%(uid)s"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
flow:
|
||||
- model: authentik_flows.flow
|
||||
identifiers:
|
||||
slug: "%(uid)s-flow"
|
||||
attrs:
|
||||
name: "%(uid)s-flow"
|
||||
title: "%(uid)s-flow"
|
||||
designation: authentication
|
||||
background: https://goauthentik.io/img/icon.png
|
||||
user:
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
password: "%(uid)s"
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s-no-password"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
endpoint:
|
||||
- model: authentik_endpoints_connectors_agent.agentconnector
|
||||
id: connector
|
||||
identifiers:
|
||||
name: "%(uid)s"
|
||||
- model: authentik_endpoints_connectors_agent.enrollmenttoken
|
||||
identifiers:
|
||||
name: "%(uid)s"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
connector: !KeyOf connector
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.test import TransactionTestCase
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.core.models import Token, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.endpoints.connectors.agent.models import EnrollmentToken
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
|
||||
@@ -29,12 +30,18 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase):
|
||||
|
||||
def test_user(self):
|
||||
"""Test user"""
|
||||
user: User = User.objects.filter(username=self.uid).first()
|
||||
user = User.objects.filter(username=self.uid).first()
|
||||
self.assertIsNotNone(user)
|
||||
self.assertTrue(user.check_password(self.uid))
|
||||
|
||||
def test_user_null(self):
|
||||
"""Test user"""
|
||||
user: User = User.objects.filter(username=f"{self.uid}-no-password").first()
|
||||
user = User.objects.filter(username=f"{self.uid}-no-password").first()
|
||||
self.assertIsNotNone(user)
|
||||
self.assertFalse(user.has_usable_password())
|
||||
|
||||
def test_enrollment_token(self):
|
||||
"""Test endpoint enrollment token"""
|
||||
token = EnrollmentToken.objects.filter(name=self.uid).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(token.key, self.uid)
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
||||
instance.status,
|
||||
BlueprintInstanceStatus.UNKNOWN,
|
||||
)
|
||||
apply_blueprint(instance.pk)
|
||||
apply_blueprint.send(instance.pk).get_result(block=True)
|
||||
instance.refresh_from_db()
|
||||
self.assertEqual(instance.last_applied_hash, "")
|
||||
self.assertEqual(
|
||||
|
||||
@@ -37,14 +37,21 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||
return super().validate(attrs)
|
||||
|
||||
def create(self, validated_data: dict) -> MetaResult:
|
||||
from authentik.blueprints.v1.tasks import apply_blueprint
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
|
||||
if not self.blueprint_instance:
|
||||
LOGGER.info("Blueprint does not exist, but not required")
|
||||
return MetaResult()
|
||||
LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
|
||||
|
||||
apply_blueprint(self.blueprint_instance.pk)
|
||||
# Apply blueprint directly using Importer to avoid task context requirements
|
||||
# and prevent deadlocks when called from within another blueprint task
|
||||
blueprint_content = self.blueprint_instance.retrieve()
|
||||
importer = Importer.from_string(blueprint_content, self.blueprint_instance.context)
|
||||
valid, logs = importer.validate()
|
||||
[log.log() for log in logs]
|
||||
if valid:
|
||||
importer.apply()
|
||||
return MetaResult()
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTaskNotFound
|
||||
from dramatiq.actor import actor
|
||||
from dramatiq.middleware import Middleware
|
||||
from structlog.stdlib import get_logger
|
||||
@@ -40,7 +39,6 @@ from authentik.events.utils import sanitize_dict
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.tasks.apps import PRIORITY_HIGH
|
||||
from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@@ -191,10 +189,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
|
||||
|
||||
@actor(description=_("Apply single blueprint."))
|
||||
def apply_blueprint(instance_pk: UUID):
|
||||
try:
|
||||
self = CurrentTask.get_task()
|
||||
except CurrentTaskNotFound:
|
||||
self = Task()
|
||||
self = CurrentTask.get_task()
|
||||
self.set_uid(str(instance_pk))
|
||||
instance: BlueprintInstance | None = None
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,12 @@ from django.db import models
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_field
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
ChoiceField,
|
||||
ListField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.filters import OrderingFilter, SearchFilter
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.request import Request
|
||||
@@ -16,7 +21,7 @@ from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer, ThemedUrlsSerializer
|
||||
from authentik.rbac.filters import SecretKeyFilter
|
||||
from authentik.tenants.api.settings import FlagJSONField
|
||||
from authentik.tenants.flags import Flag
|
||||
@@ -90,7 +95,9 @@ class CurrentBrandSerializer(PassiveSerializer):
|
||||
matched_domain = CharField(source="domain")
|
||||
branding_title = CharField()
|
||||
branding_logo = CharField(source="branding_logo_url")
|
||||
branding_logo_themed_urls = ThemedUrlsSerializer(read_only=True, allow_null=True)
|
||||
branding_favicon = CharField(source="branding_favicon_url")
|
||||
branding_favicon_themed_urls = ThemedUrlsSerializer(read_only=True, allow_null=True)
|
||||
branding_custom_css = CharField()
|
||||
ui_footer_links = ListField(
|
||||
child=FooterLinkSerializer(),
|
||||
|
||||
@@ -89,14 +89,26 @@ class Brand(SerializerModel):
|
||||
"""Get branding_logo URL"""
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.branding_logo)
|
||||
|
||||
def branding_logo_themed_urls(self) -> dict[str, str] | None:
|
||||
"""Get themed URLs for branding_logo if it contains %(theme)s"""
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.branding_logo)
|
||||
|
||||
def branding_favicon_url(self) -> str:
|
||||
"""Get branding_favicon URL"""
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.branding_favicon)
|
||||
|
||||
def branding_favicon_themed_urls(self) -> dict[str, str] | None:
|
||||
"""Get themed URLs for branding_favicon if it contains %(theme)s"""
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.branding_favicon)
|
||||
|
||||
def branding_default_flow_background_url(self) -> str:
|
||||
"""Get branding_default_flow_background URL"""
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.branding_default_flow_background)
|
||||
|
||||
def branding_default_flow_background_themed_urls(self) -> dict[str, str] | None:
|
||||
"""Get themed URLs for branding_default_flow_background if it contains %(theme)s"""
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.branding_default_flow_background)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.brands.api import BrandSerializer
|
||||
|
||||
@@ -6,7 +6,6 @@ from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.brands.api import Themes
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand
|
||||
@@ -35,12 +34,14 @@ class TestBrands(APITestCase):
|
||||
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": brand.domain,
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -55,12 +56,14 @@ class TestBrands(APITestCase):
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "custom",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -72,12 +75,14 @@ class TestBrands(APITestCase):
|
||||
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "fallback",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -94,12 +99,14 @@ class TestBrands(APITestCase):
|
||||
response,
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "authentik-default",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -117,12 +124,14 @@ class TestBrands(APITestCase):
|
||||
response,
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "authentik-default",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -133,12 +142,14 @@ class TestBrands(APITestCase):
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "custom",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -154,12 +165,14 @@ class TestBrands(APITestCase):
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "custom-strong",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "foo.bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -175,12 +188,14 @@ class TestBrands(APITestCase):
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "custom-weak",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
@@ -256,12 +271,14 @@ class TestBrands(APITestCase):
|
||||
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
|
||||
{
|
||||
"branding_logo": "https://goauthentik.io/img/icon.png",
|
||||
"branding_logo_themed_urls": None,
|
||||
"branding_favicon": "https://goauthentik.io/img/icon.png",
|
||||
"branding_favicon_themed_urls": None,
|
||||
"branding_title": "authentik",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": brand.domain,
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"ui_theme": "automatic",
|
||||
"default_locale": "",
|
||||
"flags": self.default_flags,
|
||||
},
|
||||
|
||||
@@ -24,7 +24,7 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import UserSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.core.api.utils import ModelSerializer, ThemedUrlsSerializer
|
||||
from authentik.core.models import Application, User
|
||||
from authentik.events.logs import LogEventSerializer, capture_logs
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||
@@ -53,6 +53,9 @@ class ApplicationSerializer(ModelSerializer):
|
||||
)
|
||||
|
||||
meta_icon_url = ReadOnlyField(source="get_meta_icon")
|
||||
meta_icon_themed_urls = ThemedUrlsSerializer(
|
||||
source="get_meta_icon_themed_urls", read_only=True, allow_null=True
|
||||
)
|
||||
|
||||
def get_launch_url(self, app: Application) -> str | None:
|
||||
"""Allow formatting of launch URL"""
|
||||
@@ -102,6 +105,7 @@ class ApplicationSerializer(ModelSerializer):
|
||||
"meta_launch_url",
|
||||
"meta_icon",
|
||||
"meta_icon_url",
|
||||
"meta_icon_themed_urls",
|
||||
"meta_description",
|
||||
"meta_publisher",
|
||||
"policy_engine_mode",
|
||||
@@ -180,10 +184,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
)
|
||||
|
||||
def _filter_applications_with_launch_url(
|
||||
self, applications: QuerySet[Application]
|
||||
self, paginated_apps: QuerySet[Application]
|
||||
) -> list[Application]:
|
||||
applications = []
|
||||
for app in applications:
|
||||
for app in paginated_apps:
|
||||
if app.get_launch_url():
|
||||
applications.append(app)
|
||||
return applications
|
||||
|
||||
@@ -33,6 +33,16 @@ from authentik.endpoints.connectors.agent.auth import AgentAuth
|
||||
from authentik.rbac.api.roles import RoleSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
|
||||
"pk",
|
||||
"username",
|
||||
"name",
|
||||
"is_active",
|
||||
"last_login",
|
||||
"email",
|
||||
"attributes",
|
||||
]
|
||||
|
||||
|
||||
class PartialUserSerializer(ModelSerializer):
|
||||
"""Partial User Serializer, does not include child relations."""
|
||||
@@ -42,16 +52,7 @@ class PartialUserSerializer(ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = [
|
||||
"pk",
|
||||
"username",
|
||||
"name",
|
||||
"is_active",
|
||||
"last_login",
|
||||
"email",
|
||||
"attributes",
|
||||
"uid",
|
||||
]
|
||||
fields = PARTIAL_USER_SERIALIZER_MODEL_FIELDS + ["uid"]
|
||||
|
||||
|
||||
class RelatedGroupSerializer(ModelSerializer):
|
||||
@@ -84,6 +85,7 @@ class GroupSerializer(ModelSerializer):
|
||||
source="roles",
|
||||
required=False,
|
||||
)
|
||||
inherited_roles_obj = SerializerMethodField(allow_null=True)
|
||||
num_pk = IntegerField(read_only=True)
|
||||
|
||||
@property
|
||||
@@ -107,6 +109,13 @@ class GroupSerializer(ModelSerializer):
|
||||
return True
|
||||
return str(request.query_params.get("include_parents", "false")).lower() == "true"
|
||||
|
||||
@property
|
||||
def _should_include_inherited_roles(self) -> bool:
|
||||
request: Request = self.context.get("request", None)
|
||||
if not request:
|
||||
return True
|
||||
return str(request.query_params.get("include_inherited_roles", "false")).lower() == "true"
|
||||
|
||||
@extend_schema_field(PartialUserSerializer(many=True))
|
||||
def get_users_obj(self, instance: Group) -> list[PartialUserSerializer] | None:
|
||||
if not self._should_include_users:
|
||||
@@ -125,6 +134,15 @@ class GroupSerializer(ModelSerializer):
|
||||
return None
|
||||
return RelatedGroupSerializer(instance.parents, many=True).data
|
||||
|
||||
@extend_schema_field(RoleSerializer(many=True))
|
||||
def get_inherited_roles_obj(self, instance: Group) -> list | None:
|
||||
"""Return only inherited roles from ancestor groups (excludes direct roles)"""
|
||||
if not self._should_include_inherited_roles:
|
||||
return None
|
||||
direct_role_pks = instance.roles.values_list("pk", flat=True)
|
||||
inherited_roles = instance.all_roles().exclude(pk__in=direct_role_pks)
|
||||
return RoleSerializer(inherited_roles, many=True).data
|
||||
|
||||
def validate_is_superuser(self, superuser: bool):
|
||||
"""Ensure that the user creating this group has permissions to set the superuser flag"""
|
||||
request: Request = self.context.get("request", None)
|
||||
@@ -166,6 +184,7 @@ class GroupSerializer(ModelSerializer):
|
||||
"attributes",
|
||||
"roles",
|
||||
"roles_obj",
|
||||
"inherited_roles_obj",
|
||||
"children",
|
||||
"children_obj",
|
||||
]
|
||||
@@ -255,14 +274,21 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
return [
|
||||
StrField(Group, "name"),
|
||||
BoolField(Group, "is_superuser", nullable=True),
|
||||
JSONSearchField(Group, "attributes", suggest_nested=False),
|
||||
JSONSearchField(Group, "attributes"),
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
base_qs = Group.objects.all().prefetch_related("roles")
|
||||
|
||||
if self.serializer_class(context={"request": self.request})._should_include_users:
|
||||
base_qs = base_qs.prefetch_related("users")
|
||||
# Only fetch fields needed by PartialUserSerializer to reduce DB load and instantiation
|
||||
# time
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch(
|
||||
"users",
|
||||
queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
|
||||
)
|
||||
)
|
||||
else:
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch("users", queryset=User.objects.all().only("id"))
|
||||
@@ -281,6 +307,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
OpenApiParameter("include_children", bool, default=False),
|
||||
OpenApiParameter("include_parents", bool, default=False),
|
||||
OpenApiParameter("include_inherited_roles", bool, default=False),
|
||||
]
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
@@ -291,6 +318,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
OpenApiParameter("include_children", bool, default=False),
|
||||
OpenApiParameter("include_parents", bool, default=False),
|
||||
OpenApiParameter("include_inherited_roles", bool, default=False),
|
||||
]
|
||||
)
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
|
||||
@@ -18,10 +18,14 @@ from authentik.core.models import Provider
|
||||
class ProviderSerializer(ModelSerializer, MetaNameSerializer):
|
||||
"""Provider Serializer"""
|
||||
|
||||
assigned_application_slug = ReadOnlyField(source="application.slug")
|
||||
assigned_application_name = ReadOnlyField(source="application.name")
|
||||
assigned_backchannel_application_slug = ReadOnlyField(source="backchannel_application.slug")
|
||||
assigned_backchannel_application_name = ReadOnlyField(source="backchannel_application.name")
|
||||
assigned_application_slug = ReadOnlyField(source="application.slug", allow_null=True)
|
||||
assigned_application_name = ReadOnlyField(source="application.name", allow_null=True)
|
||||
assigned_backchannel_application_slug = ReadOnlyField(
|
||||
source="backchannel_application.slug", allow_null=True
|
||||
)
|
||||
assigned_backchannel_application_name = ReadOnlyField(
|
||||
source="backchannel_application.name", allow_null=True
|
||||
)
|
||||
|
||||
component = SerializerMethodField()
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, ModelSerializer
|
||||
from authentik.core.api.utils import MetaNameSerializer, ModelSerializer, ThemedUrlsSerializer
|
||||
from authentik.core.models import GroupSourceConnection, Source, UserSourceConnection
|
||||
from authentik.core.types import UserSettingSerializer
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@@ -28,6 +28,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
|
||||
managed = ReadOnlyField()
|
||||
component = SerializerMethodField()
|
||||
icon_url = ReadOnlyField()
|
||||
icon_themed_urls = ThemedUrlsSerializer(read_only=True, allow_null=True)
|
||||
|
||||
def get_component(self, obj: Source) -> str:
|
||||
"""Get object component so that we know how to edit the object"""
|
||||
@@ -57,6 +58,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
|
||||
"user_path_template",
|
||||
"icon",
|
||||
"icon_url",
|
||||
"icon_themed_urls",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -518,7 +518,7 @@ class UserViewSet(
|
||||
StrField(User, "path"),
|
||||
BoolField(User, "is_active", nullable=True),
|
||||
ChoiceSearchField(User, "type"),
|
||||
JSONSearchField(User, "attributes", suggest_nested=False),
|
||||
JSONSearchField(User, "attributes"),
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
|
||||
@@ -127,3 +127,10 @@ class LinkSerializer(PassiveSerializer):
|
||||
"""Returns a single link"""
|
||||
|
||||
link = CharField()
|
||||
|
||||
|
||||
class ThemedUrlsSerializer(PassiveSerializer):
|
||||
"""Themed URLs - maps theme names to URLs for light and dark themes"""
|
||||
|
||||
light = CharField(required=False, allow_null=True)
|
||||
dark = CharField(required=False, allow_null=True)
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
from django.contrib.auth import logout
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.translation import override
|
||||
@@ -47,7 +47,7 @@ async def aget_user(request):
|
||||
|
||||
|
||||
class AuthenticationMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
def process_request(self, request: HttpRequest) -> HttpResponseBadRequest | None:
|
||||
if not hasattr(request, "session"):
|
||||
raise ImproperlyConfigured(
|
||||
"The Django authentication middleware requires session "
|
||||
@@ -62,7 +62,8 @@ class AuthenticationMiddleware(MiddlewareMixin):
|
||||
user = request.user
|
||||
if user and user.is_authenticated and not user.is_active:
|
||||
logout(request)
|
||||
raise AssertionError()
|
||||
return HttpResponseBadRequest()
|
||||
return None
|
||||
|
||||
|
||||
class ImpersonateMiddleware:
|
||||
|
||||
@@ -18,10 +18,9 @@ def migrate_object_permissions(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||
RoleModelPermission = apps.get_model("guardian", "RoleModelPermission")
|
||||
|
||||
def get_role_for_user_id(user_id: int) -> Role:
|
||||
name = f"ak-managed-role--user-{user_id}"
|
||||
name = f"ak-migrated-role--user-{user_id}"
|
||||
role, created = Role.objects.using(db_alias).get_or_create(
|
||||
name=name,
|
||||
managed=name,
|
||||
)
|
||||
if created:
|
||||
role.users.add(user_id)
|
||||
@@ -32,11 +31,10 @@ def migrate_object_permissions(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||
if not role:
|
||||
# Every django group should already have a role, so this should never happen.
|
||||
# But let's be nice.
|
||||
name = f"ak-managed-role--group-{group_id}"
|
||||
name = f"ak-migrated-role--group-{group_id}"
|
||||
role, created = Role.objects.using(db_alias).get_or_create(
|
||||
group_id=group_id,
|
||||
name=name,
|
||||
managed=name,
|
||||
)
|
||||
if created:
|
||||
role.group_id = group_id
|
||||
|
||||
@@ -713,9 +713,15 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.meta_icon)
|
||||
|
||||
def get_launch_url(
|
||||
self, user: Optional["User"] = None, user_data: dict | None = None
|
||||
) -> str | None:
|
||||
@property
|
||||
def get_meta_icon_themed_urls(self) -> dict[str, str] | None:
|
||||
"""Get themed URLs for meta_icon if it contains %(theme)s"""
|
||||
if not self.meta_icon:
|
||||
return None
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.meta_icon)
|
||||
|
||||
def get_launch_url(self, user: User | None = None, user_data: dict | None = None) -> str | None:
|
||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider.
|
||||
|
||||
Args:
|
||||
@@ -929,6 +935,14 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.icon)
|
||||
|
||||
@property
|
||||
def icon_themed_urls(self) -> dict[str, str] | None:
|
||||
"""Get themed URLs for icon if it contains %(theme)s"""
|
||||
if not self.icon:
|
||||
return None
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.icon)
|
||||
|
||||
def get_user_path(self) -> str:
|
||||
"""Get user path, fallback to default for formatting errors"""
|
||||
try:
|
||||
|
||||
@@ -66,9 +66,12 @@ class SessionStore(SessionBase):
|
||||
def decode(self, session_data):
|
||||
try:
|
||||
return pickle.loads(session_data) # nosec
|
||||
except pickle.PickleError:
|
||||
# ValueError, unpickling exceptions. If any of these happen, just return an empty
|
||||
# dictionary (an empty session)
|
||||
except (pickle.PickleError, AttributeError, TypeError):
|
||||
# PickleError, ValueError - unpickling exceptions
|
||||
# AttributeError - can happen when Django model fields (e.g., FileField) are unpickled
|
||||
# and their descriptors fail to initialize (e.g., missing storage)
|
||||
# TypeError - can happen with incompatible pickled objects
|
||||
# If any of these happen, just return an empty dictionary (an empty session)
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
@@ -35,8 +35,13 @@ def clean_expired_models():
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
clear_expired_cache()
|
||||
Message.delete_expired()
|
||||
GroupChannel.delete_expired()
|
||||
for cls in [Message, GroupChannel]:
|
||||
objects = cls.objects.all().filter(expires__lt=now())
|
||||
amount = objects.count()
|
||||
for obj in chunked_queryset(objects):
|
||||
obj.delete()
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
|
||||
|
||||
@actor(description=_("Remove temporary users created by SAML Sources."))
|
||||
|
||||
@@ -10,15 +10,23 @@
|
||||
{% elif ui_theme == "light" %}
|
||||
<meta name="color-scheme" content="light" />
|
||||
<meta name="theme-color" content="#ffffff">
|
||||
{% else %}
|
||||
{% else %}
|
||||
<script data-id="theme-script">
|
||||
"use strict";
|
||||
|
||||
(function () {
|
||||
try {
|
||||
/* Ignore older theme names */
|
||||
let locallyStoredTheme = window.localStorage?.getItem("theme") || null;
|
||||
if (typeof locallyStoredTheme === "string") {
|
||||
locallyStoredTheme = locallyStoredTheme.trim();
|
||||
}
|
||||
if (!(["auto", "light", "dark"].includes(locallyStoredTheme))) {
|
||||
locallyStoredTheme = null;
|
||||
}
|
||||
|
||||
const initialThemeChoice =
|
||||
new URLSearchParams(window.location.search).get("theme") ||
|
||||
window.localStorage?.getItem("theme");
|
||||
new URLSearchParams(window.location.search).get("theme") || locallyStoredTheme;
|
||||
|
||||
const themeChoice =
|
||||
initialThemeChoice || document.documentElement.dataset.themeChoice || "auto";
|
||||
|
||||
@@ -107,6 +107,8 @@ class TestApplicationsAPI(APITestCase):
|
||||
"provider_obj": {
|
||||
"assigned_application_name": "allowed",
|
||||
"assigned_application_slug": "allowed",
|
||||
"assigned_backchannel_application_name": None,
|
||||
"assigned_backchannel_application_slug": None,
|
||||
"authentication_flow": None,
|
||||
"invalidation_flow": None,
|
||||
"authorization_flow": str(self.provider.authorization_flow.pk),
|
||||
@@ -125,6 +127,7 @@ class TestApplicationsAPI(APITestCase):
|
||||
"open_in_new_tab": True,
|
||||
"meta_icon": "",
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
"meta_description": "",
|
||||
"meta_publisher": "",
|
||||
"policy_engine_mode": "any",
|
||||
@@ -162,6 +165,8 @@ class TestApplicationsAPI(APITestCase):
|
||||
"provider_obj": {
|
||||
"assigned_application_name": "allowed",
|
||||
"assigned_application_slug": "allowed",
|
||||
"assigned_backchannel_application_name": None,
|
||||
"assigned_backchannel_application_slug": None,
|
||||
"authentication_flow": None,
|
||||
"invalidation_flow": None,
|
||||
"authorization_flow": str(self.provider.authorization_flow.pk),
|
||||
@@ -180,6 +185,7 @@ class TestApplicationsAPI(APITestCase):
|
||||
"open_in_new_tab": True,
|
||||
"meta_icon": "",
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
"meta_description": "",
|
||||
"meta_publisher": "",
|
||||
"policy_engine_mode": "any",
|
||||
@@ -189,6 +195,7 @@ class TestApplicationsAPI(APITestCase):
|
||||
"meta_description": "",
|
||||
"meta_icon": "",
|
||||
"meta_icon_url": None,
|
||||
"meta_icon_themed_urls": None,
|
||||
"meta_launch_url": "",
|
||||
"open_in_new_tab": False,
|
||||
"meta_publisher": "",
|
||||
|
||||
@@ -3,11 +3,10 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.endpoints.api.connectors import ConnectorSerializer
|
||||
from authentik.endpoints.models import EndpointStage
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.flows.api.stages import StageSerializer
|
||||
|
||||
|
||||
class EndpointStageSerializer(EnterpriseRequiredMixin, StageSerializer):
|
||||
class EndpointStageSerializer(StageSerializer):
|
||||
"""EndpointStage Serializer"""
|
||||
|
||||
connector_obj = ConnectorSerializer(source="connector", read_only=True)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.db import models
|
||||
from rest_framework.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
@@ -14,6 +15,12 @@ from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.views.jwks import JWKSView
|
||||
|
||||
try:
|
||||
from authentik.enterprise.models import LicenseUsageStatus
|
||||
except ImportError:
|
||||
|
||||
class LicenseUsageStatus(models.TextChoices): ...
|
||||
|
||||
|
||||
class AgentConfigSerializer(PassiveSerializer):
|
||||
|
||||
@@ -29,6 +36,7 @@ class AgentConfigSerializer(PassiveSerializer):
|
||||
auth_terminate_session_on_expiry = BooleanField()
|
||||
|
||||
system_config = SerializerMethodField()
|
||||
license_status = SerializerMethodField(required=False, allow_null=True)
|
||||
|
||||
def get_device_id(self, instance: AgentConnector) -> str:
|
||||
device: Device = self.context["device"]
|
||||
@@ -54,6 +62,14 @@ class AgentConfigSerializer(PassiveSerializer):
|
||||
def get_system_config(self, instance: AgentConnector) -> ConfigSerializer:
|
||||
return ConfigView.get_config(self.context["request"]).data
|
||||
|
||||
def get_license_status(self, instance: AgentConnector) -> "LicenseUsageStatus":
|
||||
try:
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.cached_summary().status
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
class EnrollSerializer(PassiveSerializer):
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import cast
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied, ValidationError
|
||||
from rest_framework.fields import ChoiceField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@@ -22,6 +24,9 @@ from authentik.endpoints.connectors.agent.api.agent import (
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
AgentAuth,
|
||||
AgentEnrollmentAuth,
|
||||
DeviceAuthFedAuthentication,
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.controller import MDMConfigResponseSerializer
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
@@ -32,7 +37,10 @@ from authentik.endpoints.connectors.agent.models import (
|
||||
)
|
||||
from authentik.endpoints.facts import DeviceFacts, OSFamily
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
from authentik.lib.utils.reflection import ConditionalInheritance
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
|
||||
|
||||
class AgentConnectorSerializer(ConnectorSerializer):
|
||||
@@ -163,3 +171,43 @@ class AgentConnectorViewSet(
|
||||
connection: AgentDeviceConnection = token.device
|
||||
connection.create_snapshot(data.validated_data)
|
||||
return Response(status=204)
|
||||
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
parameters=[OpenApiParameter("device", OpenApiTypes.STR, location="query", required=True)],
|
||||
responses={
|
||||
200: AgentTokenResponseSerializer(),
|
||||
404: OpenApiResponse(description="Device not found"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
filter_backends=[],
|
||||
permission_classes=[IsAuthenticated],
|
||||
authentication_classes=[DeviceAuthFedAuthentication],
|
||||
)
|
||||
def auth_fed(self, request: Request) -> Response:
|
||||
federated_token, device, connector = request.auth
|
||||
|
||||
policy_result = check_device_policies(device, federated_token.user, request._request)
|
||||
if not policy_result.passing:
|
||||
raise ValidationError(
|
||||
{"policy_result": "Policy denied access", "policy_messages": policy_result.messages}
|
||||
)
|
||||
|
||||
token, exp = agent_auth_issue_token(device, connector, federated_token.user)
|
||||
rel_exp = int((exp - now()).total_seconds())
|
||||
Event.new(
|
||||
EventAction.LOGIN,
|
||||
**{
|
||||
PLAN_CONTEXT_METHOD: "jwt",
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"jwt": federated_token,
|
||||
"provider": federated_token.provider,
|
||||
},
|
||||
PLAN_CONTEXT_DEVICE: device,
|
||||
},
|
||||
).from_http(request, user=federated_token.user)
|
||||
return Response({"token": token, "expires_in": rel_exp})
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.tokens import TokenViewSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
@@ -19,6 +21,11 @@ class EnrollmentTokenSerializer(ModelSerializer):
|
||||
source="device_group", read_only=True, required=False
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["key"] = CharField(required=False)
|
||||
|
||||
class Meta:
|
||||
model = EnrollmentToken
|
||||
fields = [
|
||||
|
||||
@@ -1,13 +1,28 @@
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from jwt import PyJWTError, decode, encode
|
||||
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import IPCUser, validate_auth
|
||||
from authentik.core.middleware import CTX_AUTH_VIA
|
||||
from authentik.core.models import User
|
||||
from authentik.endpoints.connectors.agent.models import DeviceToken, EnrollmentToken
|
||||
from authentik.crypto.apps import MANAGED_KEY
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceToken, EnrollmentToken
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLATFORM_ISSUER = "goauthentik.io/platform"
|
||||
|
||||
|
||||
class DeviceUser(IPCUser):
|
||||
@@ -40,3 +55,96 @@ class AgentAuth(BaseAuthentication):
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token")
|
||||
return (DeviceUser(), device_token)
|
||||
|
||||
|
||||
def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
|
||||
kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
|
||||
if not kp:
|
||||
return None, None
|
||||
exp = now() + timedelta_from_string(connector.auth_session_duration)
|
||||
token = encode(
|
||||
{
|
||||
"iss": PLATFORM_ISSUER,
|
||||
"aud": str(device.pk),
|
||||
"iat": int(now().timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"preferred_username": user.username,
|
||||
**kwargs,
|
||||
},
|
||||
kp.private_key,
|
||||
headers={
|
||||
"kid": kp.kid,
|
||||
},
|
||||
algorithm=JWTAlgorithms.from_private_key(kp.private_key),
|
||||
)
|
||||
return token, exp
|
||||
|
||||
|
||||
class DeviceAuthFedAuthentication(BaseAuthentication):
|
||||
|
||||
def authenticate(self, request):
|
||||
raw_token = validate_auth(get_authorization_header(request))
|
||||
if not raw_token:
|
||||
LOGGER.warning("Missing token")
|
||||
return None
|
||||
device = Device.filter_not_expired(name=request.query_params.get("device")).first()
|
||||
if not device:
|
||||
LOGGER.warning("Couldn't find device")
|
||||
return None
|
||||
connectors_for_device = AgentConnector.objects.filter(device__in=[device])
|
||||
connector = connectors_for_device.first()
|
||||
providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
|
||||
federated_token = AccessToken.objects.filter(
|
||||
token=raw_token, provider__in=providers
|
||||
).first()
|
||||
if not federated_token:
|
||||
LOGGER.warning("Couldn't lookup provider")
|
||||
return None
|
||||
_key, _alg = federated_token.provider.jwt_key
|
||||
try:
|
||||
decode(
|
||||
raw_token,
|
||||
_key.public_key(),
|
||||
algorithms=[_alg],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
},
|
||||
)
|
||||
LOGGER.info(
|
||||
"successfully verified JWT with provider", provider=federated_token.provider.name
|
||||
)
|
||||
return (federated_token.user, (federated_token, device, connector))
|
||||
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||
LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
|
||||
return None
|
||||
|
||||
|
||||
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
|
||||
"""Auth schema"""
|
||||
|
||||
target_class = DeviceAuthFedAuthentication
|
||||
name = "device_federation"
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
"""Auth schema"""
|
||||
return {"type": "http", "scheme": "bearer"}
|
||||
|
||||
|
||||
def check_device_policies(device: Device, user: User, request: HttpRequest):
|
||||
"""Check policies bound to device group and device"""
|
||||
if device.access_group:
|
||||
result = check_pbm_policies(device.access_group, user, request)
|
||||
if result.passing:
|
||||
return result
|
||||
return check_pbm_policies(device, user, request)
|
||||
|
||||
|
||||
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
|
||||
policy_engine = PolicyEngine(pbm, user, request)
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.mode = pbm.policy_engine_mode
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=result, pbm=pbm.pk)
|
||||
return result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from hashlib import sha256
|
||||
from hmac import compare_digest
|
||||
|
||||
from django.http import HttpResponse
|
||||
@@ -8,7 +9,7 @@ from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import DeviceToken
|
||||
from authentik.endpoints.connectors.agent.models import DeviceAuthenticationToken, DeviceToken
|
||||
from authentik.endpoints.models import Device, EndpointStage, StageMode
|
||||
from authentik.flows.challenge import (
|
||||
Challenge,
|
||||
@@ -20,6 +21,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.models import JWTAlgorithms
|
||||
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN = "goauthentik.io/endpoints/device_auth_token" # nosec
|
||||
PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE = "goauthentik.io/endpoints/connectors/agent/challenge"
|
||||
QS_CHALLENGE = "challenge"
|
||||
QS_CHALLENGE_RESPONSE = "response"
|
||||
@@ -85,12 +87,36 @@ class AuthenticatorEndpointStageView(ChallengeStageView):
|
||||
response_class = EndpointAgentChallengeResponse
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
# Check if we're in a device interactive auth flow, in which case we use that
|
||||
# to prove which device is being used
|
||||
if response := self.check_device_ia():
|
||||
return response
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
keypair = CertificateKeyPair.objects.filter(pk=stage.connector.challenge_key_id).first()
|
||||
if not keypair:
|
||||
return self.executor.stage_ok()
|
||||
return super().get(request, *args, **kwargs)
|
||||
|
||||
def check_device_ia(self):
|
||||
"""Check if we're in a device interactive authentication flow, and if so,
|
||||
there won't be a browser extension to talk to. However we can authenticate
|
||||
on the DTH header"""
|
||||
if PLAN_CONTEXT_DEVICE_AUTH_TOKEN not in self.executor.plan.context:
|
||||
return None
|
||||
auth_token: DeviceAuthenticationToken = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN
|
||||
)
|
||||
device_token_hash = self.request.headers.get("X-Authentik-Platform-Auth-DTH")
|
||||
if not device_token_hash:
|
||||
return None
|
||||
if not compare_digest(
|
||||
device_token_hash, sha256(auth_token.device_token.key.encode()).hexdigest()
|
||||
):
|
||||
return self.executor.stage_invalid("Invalid device token")
|
||||
self.logger.debug("Setting device based on DTH header")
|
||||
self.executor.plan.context[PLAN_CONTEXT_DEVICE] = auth_token.device
|
||||
return self.executor.stage_ok()
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
keypair = CertificateKeyPair.objects.get(pk=stage.connector.challenge_key_id)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from hashlib import sha256
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
@@ -7,10 +8,14 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
AgentConnector,
|
||||
AgentDeviceConnection,
|
||||
DeviceAuthenticationToken,
|
||||
DeviceToken,
|
||||
EnrollmentToken,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.stage import PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE
|
||||
from authentik.endpoints.connectors.agent.stage import (
|
||||
PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN,
|
||||
)
|
||||
from authentik.endpoints.models import Device, EndpointStage, StageMode
|
||||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
@@ -35,6 +40,11 @@ class TestEndpointStage(FlowTestCase):
|
||||
device=self.connection,
|
||||
key=generate_id(),
|
||||
)
|
||||
self.device_auth_token = DeviceAuthenticationToken.objects.create(
|
||||
device=self.device,
|
||||
device_token=self.device_token,
|
||||
connector=self.connector,
|
||||
)
|
||||
|
||||
def test_endpoint_stage(self):
|
||||
flow = create_test_flow()
|
||||
@@ -194,3 +204,31 @@ class TestEndpointStage(FlowTestCase):
|
||||
"response": [{"string": "Invalid challenge response", "code": "invalid"}]
|
||||
},
|
||||
)
|
||||
|
||||
def test_endpoint_stage_ia_dth(self):
|
||||
"""Test with DTH"""
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
# Send an "invalid" request first, to populate the flow plan
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
plan = self.get_flow_plan()
|
||||
plan.context[PLAN_CONTEXT_DEVICE_AUTH_TOKEN] = DeviceAuthenticationToken.objects.get(
|
||||
pk=self.device_auth_token.pk
|
||||
)
|
||||
self.set_flow_plan(plan)
|
||||
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
HTTP_X_AUTHENTIK_PLATFORM_AUTH_DTH=sha256(
|
||||
self.device_token.key.encode()
|
||||
).hexdigest(),
|
||||
)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertEqual(plan.context[PLAN_CONTEXT_DEVICE], self.device)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Enterprise API Views"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from functools import wraps
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -35,6 +37,18 @@ class EnterpriseRequiredMixin:
|
||||
return super().validate(attrs)
|
||||
|
||||
|
||||
def enterprise_action(func: Callable):
|
||||
"""Check permissions for a single custom action"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Response:
|
||||
if not LicenseKey.cached_summary().status.is_valid:
|
||||
raise ValidationError(_("Enterprise is required to use this endpoint."))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class LicenseSerializer(ModelSerializer):
|
||||
"""License Serializer"""
|
||||
|
||||
|
||||
@@ -1,31 +1,20 @@
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.endpoints.connectors.agent.api.agent import (
|
||||
AgentAuthenticationResponse,
|
||||
AgentTokenResponseSerializer,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.auth import AgentAuth
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
DeviceAuthenticationToken,
|
||||
DeviceToken,
|
||||
)
|
||||
from authentik.enterprise.endpoints.connectors.agent.auth import (
|
||||
DeviceAuthFedAuthentication,
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
from authentik.enterprise.api import enterprise_action
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@@ -37,6 +26,7 @@ class AgentConnectorViewSetMixin:
|
||||
responses=AgentAuthenticationResponse(),
|
||||
)
|
||||
@action(methods=["POST"], detail=False, authentication_classes=[AgentAuth])
|
||||
@enterprise_action
|
||||
def auth_ia(self, request: Request) -> Response:
|
||||
token: DeviceToken = request.auth
|
||||
auth_token = DeviceAuthenticationToken.objects.create(
|
||||
@@ -54,43 +44,3 @@ class AgentConnectorViewSetMixin:
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
parameters=[OpenApiParameter("device", OpenApiTypes.STR, location="query", required=True)],
|
||||
responses={
|
||||
200: AgentTokenResponseSerializer(),
|
||||
404: OpenApiResponse(description="Device not found"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
filter_backends=[],
|
||||
permission_classes=[IsAuthenticated],
|
||||
authentication_classes=[DeviceAuthFedAuthentication],
|
||||
)
|
||||
def auth_fed(self, request: Request) -> Response:
|
||||
federated_token, device, connector = request.auth
|
||||
|
||||
policy_result = check_device_policies(device, federated_token.user, request._request)
|
||||
if not policy_result.passing:
|
||||
raise ValidationError(
|
||||
{"policy_result": "Policy denied access", "policy_messages": policy_result.messages}
|
||||
)
|
||||
|
||||
token, exp = agent_auth_issue_token(device, connector, federated_token.user)
|
||||
rel_exp = int((exp - now()).total_seconds())
|
||||
Event.new(
|
||||
EventAction.LOGIN,
|
||||
**{
|
||||
PLAN_CONTEXT_METHOD: "jwt",
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"jwt": federated_token,
|
||||
"provider": federated_token.provider,
|
||||
},
|
||||
PLAN_CONTEXT_DEVICE: device,
|
||||
},
|
||||
).from_http(request, user=federated_token.user)
|
||||
return Response({"token": token, "expires_in": rel_exp})
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from jwt import PyJWTError, decode, encode
|
||||
from rest_framework.authentication import BaseAuthentication
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import get_authorization_header, validate_auth
|
||||
from authentik.core.models import User
|
||||
from authentik.crypto.apps import MANAGED_KEY
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLATFORM_ISSUER = "goauthentik.io/platform"
|
||||
|
||||
|
||||
def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
|
||||
kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
|
||||
if not kp:
|
||||
return None, None
|
||||
exp = now() + timedelta_from_string(connector.auth_session_duration)
|
||||
token = encode(
|
||||
{
|
||||
"iss": PLATFORM_ISSUER,
|
||||
"aud": str(device.pk),
|
||||
"iat": int(now().timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"preferred_username": user.username,
|
||||
**kwargs,
|
||||
},
|
||||
kp.private_key,
|
||||
headers={
|
||||
"kid": kp.kid,
|
||||
},
|
||||
algorithm=JWTAlgorithms.from_private_key(kp.private_key),
|
||||
)
|
||||
return token, exp
|
||||
|
||||
|
||||
class DeviceAuthFedAuthentication(BaseAuthentication):
|
||||
|
||||
def authenticate(self, request):
|
||||
raw_token = validate_auth(get_authorization_header(request))
|
||||
if not raw_token:
|
||||
LOGGER.warning("Missing token")
|
||||
return None
|
||||
device = Device.filter_not_expired(name=request.query_params.get("device")).first()
|
||||
if not device:
|
||||
LOGGER.warning("Couldn't find device")
|
||||
return None
|
||||
connectors_for_device = AgentConnector.objects.filter(device__in=[device])
|
||||
connector = connectors_for_device.first()
|
||||
providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
|
||||
federated_token = AccessToken.objects.filter(
|
||||
token=raw_token, provider__in=providers
|
||||
).first()
|
||||
if not federated_token:
|
||||
LOGGER.warning("Couldn't lookup provider")
|
||||
return None
|
||||
_key, _alg = federated_token.provider.jwt_key
|
||||
try:
|
||||
decode(
|
||||
raw_token,
|
||||
_key.public_key(),
|
||||
algorithms=[_alg],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
},
|
||||
)
|
||||
LOGGER.info(
|
||||
"successfully verified JWT with provider", provider=federated_token.provider.name
|
||||
)
|
||||
return (federated_token.user, (federated_token, device, connector))
|
||||
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||
LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
|
||||
return None
|
||||
|
||||
|
||||
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
|
||||
"""Auth schema"""
|
||||
|
||||
target_class = DeviceAuthFedAuthentication
|
||||
name = "device_federation"
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
"""Auth schema"""
|
||||
return {"type": "http", "scheme": "bearer"}
|
||||
|
||||
|
||||
def check_device_policies(device: Device, user: User, request: HttpRequest):
|
||||
"""Check policies bound to device group and device"""
|
||||
if device.access_group:
|
||||
result = check_pbm_policies(device.access_group, user, request)
|
||||
if result.passing:
|
||||
return result
|
||||
return check_pbm_policies(device, user, request)
|
||||
|
||||
|
||||
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
|
||||
policy_engine = PolicyEngine(pbm, user, request)
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.mode = pbm.policy_engine_mode
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=result, pbm=pbm.pk)
|
||||
return result
|
||||
@@ -63,8 +63,21 @@ class TestConnectorAuthIA(FlowTestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@reconcile_app("authentik_crypto")
|
||||
def test_auth_ia_fulfill(self):
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-auth-ia"),
|
||||
|
||||
@@ -3,12 +3,13 @@ from hmac import compare_digest
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, QueryDict
|
||||
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceAuthenticationToken
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.enterprise.endpoints.connectors.agent.auth import (
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceAuthenticationToken
|
||||
from authentik.endpoints.connectors.agent.stage import PLAN_CONTEXT_DEVICE_AUTH_TOKEN
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
@@ -16,8 +17,6 @@ from authentik.flows.planner import PLAN_CONTEXT_DEVICE, FlowPlanner
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme
|
||||
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN = "goauthentik.io/endpoints/device_auth_token" # nosec
|
||||
|
||||
QS_AGENT_IA_TOKEN = "ak-auth-ia-token" # nosec
|
||||
|
||||
|
||||
|
||||
@@ -44,20 +44,13 @@ class GoogleWorkspaceGroupClient(
|
||||
email=f"{slugify(obj.name)}@{self.provider.default_group_email_domain}",
|
||||
)
|
||||
|
||||
def delete(self, obj: Group):
|
||||
def delete(self, identifier: str):
|
||||
"""Delete group"""
|
||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, group=obj
|
||||
).first()
|
||||
if not google_group:
|
||||
self.logger.debug("Group does not exist in Google, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
if self.provider.group_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
self._request(
|
||||
self.directory_service.groups().delete(groupKey=google_group.google_id)
|
||||
)
|
||||
google_group.delete()
|
||||
GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, google_id=identifier
|
||||
).delete()
|
||||
if self.provider.group_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
return self._request(self.directory_service.groups().delete(groupKey=identifier))
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
|
||||
@@ -35,28 +35,17 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
||||
"""Convert authentik user"""
|
||||
return delete_none_values(super().to_schema(obj, connection, primaryEmail=obj.email))
|
||||
|
||||
def delete(self, obj: User):
|
||||
def delete(self, identifier: str):
|
||||
"""Delete user"""
|
||||
google_user = GoogleWorkspaceProviderUser.objects.filter(
|
||||
provider=self.provider, user=obj
|
||||
).first()
|
||||
if not google_user:
|
||||
self.logger.debug("User does not exist in Google, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
response = None
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
response = self._request(
|
||||
self.directory_service.users().delete(userKey=google_user.google_id)
|
||||
)
|
||||
elif self.provider.user_delete_action == OutgoingSyncDeleteAction.SUSPEND:
|
||||
response = self._request(
|
||||
self.directory_service.users().update(
|
||||
userKey=google_user.google_id, body={"suspended": True}
|
||||
)
|
||||
)
|
||||
google_user.delete()
|
||||
return response
|
||||
GoogleWorkspaceProviderUser.objects.filter(
|
||||
provider=self.provider, google_id=identifier
|
||||
).delete()
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
return self._request(self.directory_service.users().delete(userKey=identifier))
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.SUSPEND:
|
||||
return self._request(
|
||||
self.directory_service.users().update(userKey=identifier, body={"suspended": True})
|
||||
)
|
||||
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
|
||||
@@ -152,6 +152,18 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
return Group.objects.all().order_by("pk")
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
|
||||
@classmethod
|
||||
def get_object_mappings(cls, obj: User | Group) -> list[tuple[str, str]]:
|
||||
if isinstance(obj, User):
|
||||
return list(
|
||||
obj.googleworkspaceprovideruser_set.values_list("provider__pk", "google_id")
|
||||
)
|
||||
if isinstance(obj, Group):
|
||||
return list(
|
||||
obj.googleworkspaceprovidergroup_set.values_list("provider__pk", "google_id")
|
||||
)
|
||||
raise ValueError(f"Invalid type {type(obj)}")
|
||||
|
||||
def google_credentials(self):
|
||||
return {
|
||||
"credentials": Credentials.from_service_account_info(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync_delete_dispatch,
|
||||
google_workspace_sync_direct_dispatch,
|
||||
google_workspace_sync_m2m_dispatch,
|
||||
)
|
||||
@@ -10,5 +11,6 @@ from authentik.lib.sync.outgoing.signals import register_signals
|
||||
register_signals(
|
||||
GoogleWorkspaceProvider,
|
||||
task_sync_direct_dispatch=google_workspace_sync_direct_dispatch,
|
||||
task_sync_delete_dispatch=google_workspace_sync_delete_dispatch,
|
||||
task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,18 @@ def google_workspace_sync_direct(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct(*args, **kwargs)
|
||||
|
||||
|
||||
@actor(
|
||||
description=_("Dispatch deletions for an object (user, group) for Google Workspace providers.")
|
||||
)
|
||||
def google_workspace_sync_delete_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_delete_dispatch(google_workspace_sync_delete, *args, **kwargs)
|
||||
|
||||
|
||||
@actor(description=_("Delete an object (user, group) for Google Workspace provider."))
|
||||
def google_workspace_sync_delete(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_delete(*args, **kwargs)
|
||||
|
||||
|
||||
@actor(
|
||||
description=_(
|
||||
"Dispatch syncs for a direct object (user, group) for Google Workspace providers."
|
||||
|
||||
@@ -48,18 +48,13 @@ class MicrosoftEntraGroupClient(
|
||||
except TypeError as exc:
|
||||
raise StopSync(exc, obj) from exc
|
||||
|
||||
def delete(self, obj: Group):
|
||||
def delete(self, identifier: str):
|
||||
"""Delete group"""
|
||||
microsoft_group = MicrosoftEntraProviderGroup.objects.filter(
|
||||
provider=self.provider, group=obj
|
||||
).first()
|
||||
if not microsoft_group:
|
||||
self.logger.debug("Group does not exist in Microsoft, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
if self.provider.group_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
self._request(self.client.groups.by_group_id(microsoft_group.microsoft_id).delete())
|
||||
microsoft_group.delete()
|
||||
MicrosoftEntraProviderGroup.objects.filter(
|
||||
provider=self.provider, microsoft_id=identifier
|
||||
).delete()
|
||||
if self.provider.group_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
return self._request(self.client.groups.by_group_id(identifier).delete())
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
|
||||
@@ -43,28 +43,17 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
except TypeError as exc:
|
||||
raise StopSync(exc, obj) from exc
|
||||
|
||||
def delete(self, obj: User):
|
||||
def delete(self, identifier: str):
|
||||
"""Delete user"""
|
||||
microsoft_user = MicrosoftEntraProviderUser.objects.filter(
|
||||
provider=self.provider, user=obj
|
||||
).first()
|
||||
if not microsoft_user:
|
||||
self.logger.debug("User does not exist in Microsoft, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
response = None
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
response = self._request(
|
||||
self.client.users.by_user_id(microsoft_user.microsoft_id).delete()
|
||||
)
|
||||
elif self.provider.user_delete_action == OutgoingSyncDeleteAction.SUSPEND:
|
||||
response = self._request(
|
||||
self.client.users.by_user_id(microsoft_user.microsoft_id).patch(
|
||||
MSUser(account_enabled=False)
|
||||
)
|
||||
)
|
||||
microsoft_user.delete()
|
||||
return response
|
||||
MicrosoftEntraProviderUser.objects.filter(
|
||||
provider=self.provider, microsoft_id=identifier
|
||||
).delete()
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.DELETE:
|
||||
return self._request(self.client.users.by_user_id(identifier).delete())
|
||||
if self.provider.user_delete_action == OutgoingSyncDeleteAction.SUSPEND:
|
||||
return self._request(
|
||||
self.client.users.by_user_id(identifier).patch(MSUser(account_enabled=False))
|
||||
)
|
||||
|
||||
def get_select_fields(self) -> list[str]:
|
||||
"""All fields that should be selected when we fetch user data."""
|
||||
|
||||
@@ -141,6 +141,18 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
return Group.objects.all().order_by("pk")
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
|
||||
@classmethod
|
||||
def get_object_mappings(cls, obj: User | Group) -> list[tuple[str, str]]:
|
||||
if isinstance(obj, User):
|
||||
return list(
|
||||
obj.microsoftentraprovideruser_set.values_list("provider__pk", "microsoft_id")
|
||||
)
|
||||
if isinstance(obj, Group):
|
||||
return list(
|
||||
obj.microsoftentraprovidergroup_set.values_list("provider__pk", "microsoft_id")
|
||||
)
|
||||
raise ValueError(f"Invalid type {type(obj)}")
|
||||
|
||||
def microsoft_credentials(self):
|
||||
return {
|
||||
"credentials": ClientSecretCredential(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
|
||||
microsoft_entra_sync_delete_dispatch,
|
||||
microsoft_entra_sync_direct_dispatch,
|
||||
microsoft_entra_sync_m2m_dispatch,
|
||||
)
|
||||
@@ -10,5 +11,6 @@ from authentik.lib.sync.outgoing.signals import register_signals
|
||||
register_signals(
|
||||
MicrosoftEntraProvider,
|
||||
task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch,
|
||||
task_sync_delete_dispatch=microsoft_entra_sync_delete_dispatch,
|
||||
task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch,
|
||||
)
|
||||
|
||||
@@ -32,6 +32,18 @@ def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
|
||||
|
||||
|
||||
@actor(description=_("Delete an object (user, group) for Microsoft Entra provider."))
|
||||
def microsoft_entra_sync_delete(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_delete(*args, **kwargs)
|
||||
|
||||
|
||||
@actor(
|
||||
description=_("Dispatch deletions for an object (user, group) for Microsoft Entra providers.")
|
||||
)
|
||||
def microsoft_entra_sync_delete_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_delete_dispatch(microsoft_entra_sync_delete, *args, **kwargs)
|
||||
|
||||
|
||||
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
|
||||
def microsoft_entra_sync_m2m(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m(*args, **kwargs)
|
||||
|
||||
@@ -4,37 +4,35 @@ from django.urls import reverse
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import mixins
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.fields import CharField, SerializerMethodField
|
||||
from rest_framework.permissions import BasePermission
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import PartialUserSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.reports.models import DataExport
|
||||
from authentik.enterprise.reports.tasks import generate_export
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
|
||||
|
||||
class RequestedBySerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ("pk", "username")
|
||||
|
||||
|
||||
class ContentTypeSerializer(ModelSerializer):
|
||||
app_label = CharField(read_only=True)
|
||||
model = CharField(read_only=True)
|
||||
verbose_name_plural = SerializerMethodField()
|
||||
|
||||
def get_verbose_name_plural(self, ct: ContentType) -> str:
|
||||
return ct.model_class()._meta.verbose_name_plural
|
||||
|
||||
class Meta:
|
||||
model = ContentType
|
||||
fields = ("id", "app_label", "model")
|
||||
fields = ("id", "app_label", "model", "verbose_name_plural")
|
||||
|
||||
|
||||
class DataExportSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
requested_by = RequestedBySerializer(read_only=True)
|
||||
requested_by = PartialUserSerializer(read_only=True)
|
||||
content_type = ContentTypeSerializer(read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
@@ -7,6 +7,7 @@ from django.db import connection
|
||||
from django.db.models import Model, Q
|
||||
from djangoql.compat import text_type
|
||||
from djangoql.schema import StrField
|
||||
from djangoql.serializers import DjangoQLSchemaSerializer
|
||||
|
||||
|
||||
class JSONSearchField(StrField):
|
||||
@@ -14,10 +15,18 @@ class JSONSearchField(StrField):
|
||||
|
||||
model: Model
|
||||
|
||||
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
name=None,
|
||||
nullable=None,
|
||||
suggest_nested=False,
|
||||
fixed_structure: OrderedDict | None = None,
|
||||
):
|
||||
# Set this in the constructor to not clobber the type variable
|
||||
self.type = "relation"
|
||||
self.suggest_nested = suggest_nested
|
||||
self.fixed_structure = fixed_structure
|
||||
super().__init__(model, name, nullable)
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
@@ -57,11 +66,23 @@ class JSONSearchField(StrField):
|
||||
)
|
||||
return (x[0] for x in cursor.fetchall())
|
||||
|
||||
def get_nested_options(self) -> OrderedDict:
|
||||
def get_fixed_structure(self, serializer: DjangoQLSchemaSerializer) -> OrderedDict:
|
||||
new_dict = OrderedDict()
|
||||
if not self.fixed_structure:
|
||||
return new_dict
|
||||
new_dict.setdefault(self.relation(), {})
|
||||
for key, value in self.fixed_structure.items():
|
||||
new_dict[self.relation()][key] = serializer.serialize_field(value)
|
||||
if isinstance(value, JSONSearchField):
|
||||
new_dict.update(value.get_nested_options(serializer))
|
||||
return new_dict
|
||||
|
||||
def get_nested_options(self, serializer: DjangoQLSchemaSerializer) -> OrderedDict:
|
||||
"""Get keys of all nested objects to show autocomplete"""
|
||||
if not self.suggest_nested:
|
||||
if self.fixed_structure:
|
||||
return self.get_fixed_structure(serializer)
|
||||
return OrderedDict()
|
||||
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
|
||||
if not parent_parts:
|
||||
@@ -87,7 +108,7 @@ class JSONSearchField(StrField):
|
||||
relation_structure = defaultdict(dict)
|
||||
|
||||
for relations in self.json_field_keys():
|
||||
result = recursive_function([base_model_name] + relations)
|
||||
result = recursive_function([self.relation()] + relations)
|
||||
for relation_key, value in result.items():
|
||||
for sub_relation_key, sub_value in value.items():
|
||||
if not relation_structure[relation_key].get(sub_relation_key, None):
|
||||
|
||||
@@ -12,7 +12,7 @@ class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
|
||||
for _, field in fields.items():
|
||||
if not isinstance(field, JSONSearchField):
|
||||
continue
|
||||
serialization["models"].update(field.get_nested_options())
|
||||
serialization["models"].update(field.get_nested_options(self))
|
||||
return serialization
|
||||
|
||||
def serialize_field(self, field):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Events API Views"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
|
||||
import django_filters
|
||||
@@ -136,7 +137,7 @@ class EventViewSet(
|
||||
filterset_class = EventsFilter
|
||||
|
||||
def get_ql_fields(self):
|
||||
from djangoql.schema import DateTimeField, StrField
|
||||
from djangoql.schema import DateTimeField, IntField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
|
||||
|
||||
@@ -145,9 +146,42 @@ class EventViewSet(
|
||||
StrField(Event, "event_uuid"),
|
||||
StrField(Event, "app", suggest_options=True),
|
||||
StrField(Event, "client_ip"),
|
||||
JSONSearchField(Event, "user", suggest_nested=False),
|
||||
JSONSearchField(Event, "brand", suggest_nested=False),
|
||||
JSONSearchField(Event, "context", suggest_nested=False),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"user",
|
||||
fixed_structure=OrderedDict(
|
||||
pk=IntField(),
|
||||
username=StrField(),
|
||||
email=StrField(),
|
||||
),
|
||||
),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"brand",
|
||||
fixed_structure=OrderedDict(
|
||||
pk=StrField(),
|
||||
app=StrField(),
|
||||
name=StrField(),
|
||||
model_name=StrField(),
|
||||
),
|
||||
),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"context",
|
||||
fixed_structure=OrderedDict(
|
||||
http_request=JSONSearchField(
|
||||
Event,
|
||||
"context_http_request",
|
||||
fixed_structure=OrderedDict(
|
||||
args=JSONSearchField(Event, "context_http_request_args"),
|
||||
path=StrField(),
|
||||
method=StrField(),
|
||||
request_id=StrField(),
|
||||
user_agent=StrField(),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
DateTimeField(Event, "created", suggest_options=True),
|
||||
]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.fields import CharField, ChoiceField, DateTimeField, DictField
|
||||
from structlog import configure, get_config
|
||||
from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter
|
||||
from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter, get_logger
|
||||
from structlog.testing import LogCapture
|
||||
from structlog.types import EventDict
|
||||
|
||||
@@ -36,6 +36,9 @@ class LogEvent:
|
||||
event, log_level, item.pop("logger"), timestamp, attributes=sanitize_dict(item)
|
||||
)
|
||||
|
||||
def log(self):
|
||||
get_logger(self.logger).log(NAME_TO_LEVEL[self.log_level], self.event, **self.attributes)
|
||||
|
||||
|
||||
class LogEventSerializer(PassiveSerializer):
|
||||
"""Single log message with all context logged."""
|
||||
|
||||
@@ -8,6 +8,8 @@ from inspect import currentframe
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.apps import apps
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
@@ -41,6 +43,7 @@ from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.root.ws.consumer import build_user_group
|
||||
from authentik.stages.email.models import EmailTemplates
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tasks.models import TasksModel
|
||||
@@ -361,6 +364,15 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
notification=notification,
|
||||
)
|
||||
notification.save()
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
build_user_group(notification.user),
|
||||
{
|
||||
"type": "event.notification",
|
||||
"id": str(notification.pk),
|
||||
"data": notification.serializer(notification).data,
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
def send_webhook(self, notification: "Notification") -> list[str]:
|
||||
|
||||
@@ -22,6 +22,7 @@ from authentik.core.api.utils import (
|
||||
LinkSerializer,
|
||||
ModelSerializer,
|
||||
PassiveSerializer,
|
||||
ThemedUrlsSerializer,
|
||||
)
|
||||
from authentik.events.logs import LogEventSerializer
|
||||
from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer
|
||||
@@ -47,6 +48,7 @@ class FlowSerializer(ModelSerializer):
|
||||
"""Flow Serializer"""
|
||||
|
||||
background_url = ReadOnlyField()
|
||||
background_themed_urls = ThemedUrlsSerializer(read_only=True, allow_null=True)
|
||||
|
||||
cache_count = SerializerMethodField()
|
||||
export_url = SerializerMethodField()
|
||||
@@ -70,6 +72,7 @@ class FlowSerializer(ModelSerializer):
|
||||
"designation",
|
||||
"background",
|
||||
"background_url",
|
||||
"background_themed_urls",
|
||||
"stages",
|
||||
"policies",
|
||||
"cache_count",
|
||||
|
||||
@@ -11,7 +11,7 @@ from django.http import JsonResponse
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField
|
||||
from rest_framework.request import Request
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer, ThemedUrlsSerializer
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -44,6 +44,7 @@ class ContextualFlowInfo(PassiveSerializer):
|
||||
|
||||
title = CharField(required=False, allow_blank=True)
|
||||
background = CharField(required=False)
|
||||
background_themed_urls = ThemedUrlsSerializer(required=False, allow_null=True)
|
||||
cancel_url = CharField()
|
||||
layout = ChoiceField(choices=[(x.value, x.name) for x in FlowLayout])
|
||||
|
||||
|
||||
@@ -196,6 +196,15 @@ class Flow(SerializerModel, PolicyBindingModel):
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.background, request)
|
||||
|
||||
def background_themed_urls(self, request: HttpRequest | None = None) -> dict[str, str] | None:
|
||||
"""Get themed URLs for background if it contains %(theme)s"""
|
||||
if not self.background:
|
||||
if request:
|
||||
return request.brand.branding_default_flow_background_themed_urls()
|
||||
return None
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).themed_urls(self.background, request)
|
||||
|
||||
stages = models.ManyToManyField(Stage, through="FlowStageBinding", blank=True)
|
||||
|
||||
@property
|
||||
|
||||
@@ -38,7 +38,7 @@ def invalidate_flow_cache(sender, instance, **_):
|
||||
if isinstance(instance, Flow):
|
||||
total = delete_cache_prefix(f"{cache_key(instance)}*")
|
||||
LOGGER.debug("Invalidating Flow cache", flow=instance, len=total)
|
||||
if isinstance(instance, FlowStageBinding):
|
||||
if isinstance(instance, FlowStageBinding) and instance.target_id:
|
||||
total = delete_cache_prefix(f"{cache_key(instance.target)}*")
|
||||
LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total)
|
||||
if isinstance(instance, Stage):
|
||||
|
||||
@@ -197,6 +197,9 @@ class ChallengeStageView(StageView):
|
||||
data={
|
||||
"title": self.format_title(),
|
||||
"background": self.executor.flow.background_url(self.request),
|
||||
"background_themed_urls": self.executor.flow.background_themed_urls(
|
||||
self.request
|
||||
),
|
||||
"cancel_url": self.cancel_url,
|
||||
"layout": self.executor.flow.layout,
|
||||
}
|
||||
|
||||
@@ -48,6 +48,14 @@ class FlowTestCase(APITestCase):
|
||||
self.assertEqual(raw_response[key], expected)
|
||||
return raw_response
|
||||
|
||||
def get_flow_plan(self) -> FlowPlan | None:
|
||||
return self.client.session.get(SESSION_KEY_PLAN)
|
||||
|
||||
def set_flow_plan(self, plan: FlowPlan):
|
||||
session = self.client.session
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
|
||||
def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]:
|
||||
"""Wrapper around assertStageResponse that checks for a redirect"""
|
||||
return self.assertStageResponse(response, component="xak-flow-redirect", to=to)
|
||||
|
||||
@@ -734,6 +734,7 @@ class TestFlowExecutor(FlowTestCase):
|
||||
flow,
|
||||
flow_info={
|
||||
"background": "/static/dist/assets/images/flow_background.jpg",
|
||||
"background_themed_urls": None,
|
||||
"cancel_url": "/flows/-/cancel/?next=%2Ffoo",
|
||||
"layout": "stacked",
|
||||
"title": flow.title,
|
||||
|
||||
@@ -51,6 +51,7 @@ class TestFlowInspector(APITestCase):
|
||||
"enable_remember_me": False,
|
||||
"flow_info": {
|
||||
"background": "/static/dist/assets/images/flow_background.jpg",
|
||||
"background_themed_urls": None,
|
||||
"cancel_url": reverse("authentik_flows:cancel"),
|
||||
"title": flow.title,
|
||||
"layout": "stacked",
|
||||
|
||||
@@ -22,7 +22,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Direction(StrEnum):
|
||||
|
||||
add = "add"
|
||||
remove = "remove"
|
||||
|
||||
@@ -36,7 +35,10 @@ SAFE_METHODS = [
|
||||
|
||||
|
||||
class BaseOutgoingSyncClient[
|
||||
TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
|
||||
TModel: "Model",
|
||||
TConnection: "Model",
|
||||
TSchema: dict,
|
||||
TProvider: "OutgoingSyncProvider",
|
||||
]:
|
||||
"""Basic Outgoing sync client Client"""
|
||||
|
||||
@@ -82,7 +84,7 @@ class BaseOutgoingSyncClient[
|
||||
connection.delete()
|
||||
return None, False
|
||||
|
||||
def delete(self, obj: TModel):
|
||||
def delete(self, identifier: str):
|
||||
"""Delete object from destination"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -56,6 +56,14 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_object_mappings(cls, obj: User | Group) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Get a list of mapping between User/Group and ProviderUser/Group:
|
||||
[("provider_pk", "obj_pk")]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
|
||||
return Paginator(self.get_object_qs(type), self.sync_page_size)
|
||||
|
||||
@@ -84,7 +92,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
raise NotImplementedError
|
||||
|
||||
def sync_dispatch(self) -> None:
|
||||
for schedule in self.schedules:
|
||||
for schedule in self.schedules.all():
|
||||
schedule.send()
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,7 +6,6 @@ from django.db.models.signals import m2m_changed, post_save, pre_delete
|
||||
from dramatiq.actor import Actor
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
@@ -30,7 +29,8 @@ def sync_outgoing_inhibit_dispatch():
|
||||
|
||||
def register_signals(
|
||||
provider_type: type[OutgoingSyncProvider],
|
||||
task_sync_direct_dispatch: Actor[[str, str | int, str], None],
|
||||
task_sync_direct_dispatch: Actor[[str, str | int], None],
|
||||
task_sync_delete_dispatch: Actor[[str, list[tuple[str, str]]], None],
|
||||
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
|
||||
):
|
||||
"""Register sync signals"""
|
||||
@@ -55,7 +55,6 @@ def register_signals(
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
Direction.add.value,
|
||||
)
|
||||
|
||||
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
||||
@@ -65,12 +64,12 @@ def register_signals(
|
||||
"""Pre-delete handler"""
|
||||
if _CTX_INHIBIT_DISPATCH.get():
|
||||
return
|
||||
if not provider_type.objects.exists():
|
||||
mappings = provider_type.get_object_mappings(instance)
|
||||
if not mappings:
|
||||
return
|
||||
task_sync_direct_dispatch.send(
|
||||
task_sync_delete_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
Direction.remove.value,
|
||||
mappings,
|
||||
)
|
||||
|
||||
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
|
||||
|
||||
@@ -13,6 +13,7 @@ from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
BadRequestSyncException,
|
||||
DryRunRejected,
|
||||
NotFoundSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
@@ -189,17 +190,16 @@ class SyncTasks:
|
||||
|
||||
def sync_signal_direct_dispatch(
|
||||
self,
|
||||
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
|
||||
task_sync_signal_direct: Actor[[str, str | int, int], None],
|
||||
model: str,
|
||||
pk: str | int,
|
||||
raw_op: str,
|
||||
):
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
task_sync_signal_direct.send_with_options(
|
||||
args=(model, pk, provider.pk, raw_op),
|
||||
args=(model, pk, provider.pk),
|
||||
rel_obj=provider,
|
||||
uid=f"{provider.name}:{model_class._meta.model_name}:{pk}:direct",
|
||||
)
|
||||
@@ -209,7 +209,6 @@ class SyncTasks:
|
||||
model: str,
|
||||
pk: str | int,
|
||||
provider_pk: int,
|
||||
raw_op: str,
|
||||
):
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
@@ -219,14 +218,13 @@ class SyncTasks:
|
||||
instance = model_class.objects.filter(pk=pk).first()
|
||||
if not instance:
|
||||
return
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
provider: OutgoingSyncProvider | None = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
operation = Direction(raw_op)
|
||||
client = provider.client_for_model(instance.__class__)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset = provider.get_object_qs(instance.__class__)
|
||||
@@ -239,10 +237,7 @@ class SyncTasks:
|
||||
return
|
||||
|
||||
try:
|
||||
if operation == Direction.add:
|
||||
client.write(instance)
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
client.write(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
@@ -252,6 +247,60 @@ class SyncTasks:
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
||||
def sync_signal_delete_dispatch(
|
||||
self,
|
||||
task_sync_signal_delete: Actor[[str, int, str], None],
|
||||
model: str,
|
||||
mappings: list[tuple[str, str]],
|
||||
):
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
for provider_pk, identifier in mappings:
|
||||
provider: OutgoingSyncProvider | None = self._provider_model.objects.filter(
|
||||
pk=provider_pk
|
||||
).first()
|
||||
if not provider:
|
||||
continue
|
||||
task_sync_signal_delete.send_with_options(
|
||||
args=(model, identifier, provider_pk),
|
||||
rel_obj=provider,
|
||||
uid=f"{provider.name}:{model_class._meta.model_name}:{identifier}:delete",
|
||||
)
|
||||
|
||||
def sync_signal_delete(
|
||||
self,
|
||||
model: str,
|
||||
identifier: str,
|
||||
provider_pk: int,
|
||||
):
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
provider: OutgoingSyncProvider | None = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
client = provider.client_for_model(model_class)
|
||||
|
||||
try:
|
||||
client.delete(identifier)
|
||||
except NotFoundSyncException as exc:
|
||||
self.logger.info(
|
||||
"Object not found in remote provider",
|
||||
model_name=model_class._meta.model_name,
|
||||
identifier=identifier,
|
||||
exc=exc,
|
||||
provider_pk=provider.pk,
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
|
||||
def sync_signal_m2m_dispatch(
|
||||
self,
|
||||
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""authentik database utilities"""
|
||||
|
||||
import gc
|
||||
from collections.abc import Generator
|
||||
|
||||
from django.db import reset_queries
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models import Model, QuerySet
|
||||
|
||||
|
||||
def chunked_queryset(queryset: QuerySet, chunk_size: int = 1_000):
|
||||
def chunked_queryset[T: Model](queryset: QuerySet[T], chunk_size: int = 1_000) -> Generator[T]:
|
||||
if not queryset.exists():
|
||||
return []
|
||||
|
||||
def get_chunks(qs: QuerySet):
|
||||
def get_chunks(qs: QuerySet) -> Generator[QuerySet[T]]:
|
||||
qs = qs.order_by("pk")
|
||||
pks = qs.values_list("pk", flat=True)
|
||||
start_pk = pks[0]
|
||||
|
||||
@@ -47,7 +47,9 @@ class OutpostSerializer(ModelSerializer):
|
||||
)
|
||||
providers_obj = ProviderSerializer(source="providers", many=True, read_only=True)
|
||||
service_connection_obj = ServiceConnectionSerializer(
|
||||
source="service_connection", read_only=True
|
||||
source="service_connection",
|
||||
read_only=True,
|
||||
allow_null=True,
|
||||
)
|
||||
refresh_interval_s = SerializerMethodField()
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class OutpostConfig:
|
||||
class OutpostModel(Model):
|
||||
"""Base model for providers that need more objects than just themselves"""
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
"""Return a list of all required objects"""
|
||||
return [self]
|
||||
|
||||
@@ -332,41 +332,35 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
"""Create per-object and global permissions for outpost service-account"""
|
||||
# To ensure the user only has the correct permissions, we delete all of them and re-add
|
||||
# the ones the user needs
|
||||
with transaction.atomic():
|
||||
user.remove_all_perms_from_managed_role()
|
||||
for model_or_perm in self.get_required_objects():
|
||||
if isinstance(model_or_perm, models.Model):
|
||||
model_or_perm: models.Model
|
||||
code_name = (
|
||||
f"{model_or_perm._meta.app_label}.view_{model_or_perm._meta.model_name}"
|
||||
)
|
||||
try:
|
||||
user.assign_perms_to_managed_role(code_name, model_or_perm)
|
||||
except (Permission.DoesNotExist, AttributeError) as exc:
|
||||
LOGGER.warning(
|
||||
"permission doesn't exist",
|
||||
code_name=code_name,
|
||||
user=user,
|
||||
model=model_or_perm,
|
||||
try:
|
||||
with transaction.atomic():
|
||||
user.remove_all_perms_from_managed_role()
|
||||
for model_or_perm in self.get_required_objects():
|
||||
if isinstance(model_or_perm, models.Model):
|
||||
code_name = (
|
||||
f"{model_or_perm._meta.app_label}.view_{model_or_perm._meta.model_name}"
|
||||
)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message=(
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://docs.goauthentik.io/troubleshooting/missing_permission"
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
else:
|
||||
app_label, perm = model_or_perm.split(".")
|
||||
permission = Permission.objects.filter(
|
||||
codename=perm,
|
||||
content_type__app_label=app_label,
|
||||
)
|
||||
if not permission.exists():
|
||||
LOGGER.warning("permission doesn't exist", perm=model_or_perm)
|
||||
continue
|
||||
user.assign_perms_to_managed_role(permission.first())
|
||||
user.assign_perms_to_managed_role(code_name, model_or_perm)
|
||||
elif isinstance(model_or_perm, tuple):
|
||||
perm, obj = model_or_perm
|
||||
user.assign_perms_to_managed_role(perm, obj)
|
||||
else:
|
||||
user.assign_perms_to_managed_role(model_or_perm)
|
||||
except (Permission.DoesNotExist, AttributeError) as exc:
|
||||
LOGGER.warning(
|
||||
"permission doesn't exist",
|
||||
code_name=code_name,
|
||||
user=user,
|
||||
model=model_or_perm,
|
||||
)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message=(
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://docs.goauthentik.io/troubleshooting/missing_permission"
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
LOGGER.debug(
|
||||
"Updated service account's permissions",
|
||||
obj_perms=user.get_all_obj_perms_on_managed_role(),
|
||||
@@ -431,7 +425,7 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
Token.objects.filter(identifier=self.token_identifier).delete()
|
||||
return self.token
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
"""Get an iterator of all objects the user needs read access to"""
|
||||
objects: list[models.Model | str] = [
|
||||
self,
|
||||
@@ -445,7 +439,9 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
if self.managed:
|
||||
for brand in Brand.objects.filter(web_certificate__isnull=False):
|
||||
objects.append(brand)
|
||||
objects.append(brand.web_certificate)
|
||||
objects.append(("view_certificatekeypair", brand.web_certificate))
|
||||
objects.append(("view_certificatekeypair_certificate", brand.web_certificate))
|
||||
objects.append(("view_certificatekeypair_key", brand.web_certificate))
|
||||
return objects
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -51,10 +51,12 @@ class OutpostTests(TestCase):
|
||||
permissions = outpost.user.get_all_obj_perms_on_managed_role().order_by(
|
||||
"content_type__model"
|
||||
)
|
||||
self.assertEqual(len(permissions), 3)
|
||||
self.assertEqual(len(permissions), 5)
|
||||
self.assertEqual(permissions[0].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[1].object_pk, str(outpost.pk))
|
||||
self.assertEqual(permissions[2].object_pk, str(provider.pk))
|
||||
self.assertEqual(permissions[1].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[2].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[3].object_pk, str(outpost.pk))
|
||||
self.assertEqual(permissions[4].object_pk, str(provider.pk))
|
||||
|
||||
# Remove provider from outpost, user should only have access to outpost
|
||||
outpost.providers.remove(provider)
|
||||
|
||||
@@ -103,24 +103,28 @@ class PolicyAccessView(AccessMixin, View):
|
||||
they try to access and redirect to the login URL. The application is saved to show
|
||||
a hint on the Identification Stage what the user should login for."""
|
||||
flow_context = {}
|
||||
authn_flow = None
|
||||
if self.application:
|
||||
flow_context[PLAN_CONTEXT_APPLICATION] = self.application
|
||||
if self.provider and self.provider.authentication_flow:
|
||||
authn_flow = self.provider.authentication_flow
|
||||
# Because this view might get hit with a POST request, we need to preserve that data
|
||||
# since later views might need it (mostly SAML)
|
||||
if self.request.method.lower() == "post":
|
||||
self.request.session[SESSION_KEY_POST] = self.request.POST
|
||||
flow_context[PLAN_CONTEXT_POST] = self.request.POST
|
||||
|
||||
flow = ToDefaultFlow.get_flow(self.request, FlowDesignation.AUTHENTICATION)
|
||||
if not flow:
|
||||
raise Http404
|
||||
planner = FlowPlanner(flow)
|
||||
if not authn_flow:
|
||||
authn_flow = ToDefaultFlow.get_flow(self.request, FlowDesignation.AUTHENTICATION)
|
||||
if not authn_flow:
|
||||
raise Http404
|
||||
planner = FlowPlanner(authn_flow)
|
||||
try:
|
||||
plan = planner.plan(self.request, self.modify_flow_context(flow, flow_context))
|
||||
plan = planner.plan(self.request, self.modify_flow_context(authn_flow, flow_context))
|
||||
except (FlowNonApplicableException, EmptyFlowException) as exc:
|
||||
LOGGER.warning("Non-applicable authentication flow", exc=exc)
|
||||
raise Http404 from None
|
||||
return plan.to_redirect(self.request, flow, next=self.request.get_full_path())
|
||||
return plan.to_redirect(self.request, authn_flow, next=self.request.get_full_path())
|
||||
|
||||
def handle_no_permission_authenticated(
|
||||
self, result: PolicyResult | None = None
|
||||
|
||||
@@ -93,11 +93,13 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
def __str__(self):
|
||||
return f"LDAP Provider {self.name}"
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
required_models = [self, "authentik_core.view_user", "authentik_core.view_group"]
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
required = [self, "authentik_core.view_user", "authentik_core.view_group"]
|
||||
if self.certificate is not None:
|
||||
required_models.append(self.certificate)
|
||||
return required_models
|
||||
required.append(("view_certificatekeypair", self.certificate))
|
||||
required.append(("view_certificatekeypair_certificate", self.certificate))
|
||||
required.append(("view_certificatekeypair_key", self.certificate))
|
||||
return required
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("LDAP Provider")
|
||||
|
||||
@@ -152,7 +152,7 @@ class IDToken:
|
||||
final = self.to_dict()
|
||||
final["azp"] = provider.client_id
|
||||
final["uid"] = generate_id()
|
||||
final["scope"] = " ".join(token.scope)
|
||||
final.setdefault("scope", " ".join(token.scope))
|
||||
return provider.encode(final)
|
||||
|
||||
def to_jwt(self, provider: "OAuth2Provider") -> str:
|
||||
|
||||
@@ -386,11 +386,18 @@ class OAuth2Provider(WebfingerProvider, Provider):
|
||||
def __str__(self):
|
||||
return f"OAuth2 Provider {self.name}"
|
||||
|
||||
def encode(self, payload: dict[str, Any]) -> str:
|
||||
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
||||
def encode(self, payload: dict[str, Any], jwt_type: str | None = None) -> str:
|
||||
"""Represent the ID Token as a JSON Web Token (JWT).
|
||||
|
||||
:param payload The payload to encode into the JWT
|
||||
:param jwt_type The type of the JWT. This will be put in the JWT header using the `typ`
|
||||
parameter. See RFC7515 Section 4.1.9. If not set fallback to the default of `JWT`.
|
||||
"""
|
||||
headers = {}
|
||||
if self.signing_key:
|
||||
headers["kid"] = self.signing_key.kid
|
||||
if jwt_type is not None:
|
||||
headers["typ"] = jwt_type
|
||||
key, alg = self.jwt_key
|
||||
encoded = encode(payload, key, algorithm=alg, headers=headers)
|
||||
if self.encryption_key:
|
||||
|
||||
@@ -9,8 +9,9 @@ from django.utils.timezone import now
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||
@@ -26,6 +27,7 @@ from authentik.providers.oauth2.models import (
|
||||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
||||
from authentik.stages.dummy.models import DummyStage
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD
|
||||
|
||||
|
||||
@@ -725,3 +727,41 @@ class TestAuthorize(OAuthTestCase):
|
||||
)
|
||||
parsed = parse_qs(urlparse(response.url).query)
|
||||
self.assertNotIn("locale", parsed)
|
||||
|
||||
def test_authentication_flow(self):
|
||||
"""Test custom authentication flow"""
|
||||
brand = create_test_brand()
|
||||
global_auth = create_test_flow()
|
||||
FlowStageBinding.objects.create(
|
||||
target=global_auth, stage=DummyStage.objects.create(name=generate_id()), order=10
|
||||
)
|
||||
brand.flow_authentication = global_auth
|
||||
brand.save()
|
||||
|
||||
flow = create_test_flow()
|
||||
auth_flow = create_test_flow()
|
||||
FlowStageBinding.objects.create(
|
||||
target=auth_flow, stage=DummyStage.objects.create(name=generate_id()), order=10
|
||||
)
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=flow,
|
||||
authentication_flow=auth_flow,
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
|
||||
access_code_validity="seconds=100",
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
response = self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"state": state,
|
||||
"redirect_uri": "foo://localhost",
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertIn(auth_flow.slug, response.url)
|
||||
self.assertNotIn(global_auth.slug, response.url)
|
||||
|
||||
@@ -113,11 +113,16 @@ class TestBackChannelLogout(OAuthTestCase):
|
||||
|
||||
def _decode_token(self, token, provider=None):
|
||||
"""Helper to decode and validate a JWT token"""
|
||||
decoded = self._decode_token_complete(token, provider)
|
||||
return decoded["payload"]
|
||||
|
||||
def _decode_token_complete(self, token, provider=None):
|
||||
"""Helper to decode and validate a JWT token into a header, and payload dict"""
|
||||
provider = provider or self.provider
|
||||
key, alg = provider.jwt_key
|
||||
if alg != "HS256":
|
||||
key = provider.signing_key.public_key
|
||||
return jwt.decode(
|
||||
return jwt.decode_complete(
|
||||
token, key, algorithms=[alg], options={"verify_exp": False, "verify_aud": False}
|
||||
)
|
||||
|
||||
@@ -155,6 +160,16 @@ class TestBackChannelLogout(OAuthTestCase):
|
||||
self.assertEqual(decoded3["sub"], sub)
|
||||
self.assertIn("events", decoded3)
|
||||
|
||||
def test_create_logout_token_header_type(self):
|
||||
"""Test creating logout tokens and checking if the token header type is correct"""
|
||||
session_id = "test-session-123"
|
||||
token1 = self._create_logout_token(session_id=session_id)
|
||||
|
||||
decoded = self._decode_token_complete(token1)
|
||||
|
||||
self.assertIsNotNone(decoded["header"])
|
||||
self.assertEqual(decoded["header"]["typ"], "logout+jwt")
|
||||
|
||||
@patch("authentik.providers.oauth2.tasks.get_http_session")
|
||||
def test_send_backchannel_logout_request_scenarios(self, mock_get_session):
|
||||
"""Test various scenarios for backchannel logout request task"""
|
||||
|
||||
@@ -84,6 +84,7 @@ class TesOAuth2DeviceInit(OAuthTestCase):
|
||||
"component": "ak-provider-oauth2-device-code",
|
||||
"flow_info": {
|
||||
"background": "/static/dist/assets/images/flow_background.jpg",
|
||||
"background_themed_urls": None,
|
||||
"cancel_url": "/flows/-/cancel/",
|
||||
"layout": "stacked",
|
||||
"title": self.device_flow.title,
|
||||
|
||||
@@ -436,3 +436,57 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_scope_claim_override_via_property_mapping(self):
|
||||
"""Test that property mappings can override the scope claim in access tokens.
|
||||
|
||||
See: https://github.com/goauthentik/authentik/issues/19224
|
||||
"""
|
||||
# Create a custom scope mapping that returns a custom scope claim
|
||||
custom_scope_mapping = ScopeMapping.objects.create(
|
||||
name="custom-scope-override",
|
||||
scope_name="custom",
|
||||
expression='return {"scope": "custom-scope-value additional-scope"}',
|
||||
)
|
||||
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
include_claims_in_id_token=True,
|
||||
)
|
||||
provider.property_mappings.add(custom_scope_mapping)
|
||||
|
||||
# Needs to be assigned to an application for iss to be set
|
||||
self.app.provider = provider
|
||||
self.app.save()
|
||||
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
code = AuthorizationCode.objects.create(
|
||||
code="foobar",
|
||||
provider=provider,
|
||||
user=user,
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid custom", # Request the custom scope
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
data={
|
||||
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
|
||||
"code": code.code,
|
||||
"redirect_uri": "http://local.invalid",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
jwt_data = self.validate_jwt(access, provider)
|
||||
|
||||
# The scope should be the custom value from the property mapping,
|
||||
# not the default "openid custom"
|
||||
self.assertEqual(jwt_data["scope"], "custom-scope-value additional-scope")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user