mirror of
https://github.com/goauthentik/authentik
synced 2026-05-05 22:52:42 +02:00
Compare commits
130 Commits
version/20
...
version/20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc2332a316 | ||
|
|
c39414f558 | ||
|
|
aac1acfebd | ||
|
|
4d881bb3d2 | ||
|
|
852d392158 | ||
|
|
76b26ea288 | ||
|
|
a1f1378814 | ||
|
|
afc2be6b68 | ||
|
|
c45985e9d0 | ||
|
|
7221ed1ce6 | ||
|
|
123fd3dfb8 | ||
|
|
59c292ca21 | ||
|
|
2b247b60cf | ||
|
|
359a3b9768 | ||
|
|
2c84d73353 | ||
|
|
56ba055857 | ||
|
|
4b9775d9fe | ||
|
|
d06091e226 | ||
|
|
f715e7a537 | ||
|
|
1068dfcc28 | ||
|
|
9a6f66b23c | ||
|
|
853a367325 | ||
|
|
09cdcd1892 | ||
|
|
bed6407b52 | ||
|
|
3936a4e09a | ||
|
|
ad818a2880 | ||
|
|
f8f049f080 | ||
|
|
434e8203de | ||
|
|
7715ce1a90 | ||
|
|
c735dd67a2 | ||
|
|
1b5962be60 | ||
|
|
796d130ea4 | ||
|
|
6c8b502a5b | ||
|
|
674d681f98 | ||
|
|
8c6d3e131d | ||
|
|
b689debfed | ||
|
|
03e4297824 | ||
|
|
c4e0a02837 | ||
|
|
4586ed0735 | ||
|
|
59ef6bb6ea | ||
|
|
6ce812b01f | ||
|
|
87d08dc164 | ||
|
|
c41883b8ea | ||
|
|
6e9d510c9e | ||
|
|
d09ed8e8f0 | ||
|
|
8fe8b1e803 | ||
|
|
66438f3780 | ||
|
|
46f446fd0e | ||
|
|
f83d3a19d0 | ||
|
|
ef59ff1856 | ||
|
|
4966225282 | ||
|
|
2b8765d0aa | ||
|
|
d60d06f958 | ||
|
|
1a3f268476 | ||
|
|
515a855c40 | ||
|
|
16d65b8d12 | ||
|
|
bfe928df18 | ||
|
|
c447bbe6c8 | ||
|
|
1c0a3f95df | ||
|
|
8a6116ab79 | ||
|
|
430010fbea | ||
|
|
079b575a45 | ||
|
|
b2ca887d59 | ||
|
|
d7b30ad0d7 | ||
|
|
b084ace1dd | ||
|
|
b3e45cdf1a | ||
|
|
8132e1f7d9 | ||
|
|
149dccf244 | ||
|
|
b5e4797761 | ||
|
|
be670d6253 | ||
|
|
71060ea4e7 | ||
|
|
f60f38280c | ||
|
|
418deeb332 | ||
|
|
619c77c27e | ||
|
|
ddfddb49da | ||
|
|
dbbb1870b7 | ||
|
|
5b43301206 | ||
|
|
d915d1a94a | ||
|
|
786497790a | ||
|
|
56c899cf21 | ||
|
|
943f22e5a9 | ||
|
|
11b45689f4 | ||
|
|
87f443532f | ||
|
|
0c672a0c37 | ||
|
|
dfd11ceb57 | ||
|
|
d865b7fd87 | ||
|
|
aa8a6b9c43 | ||
|
|
fe5313f42e | ||
|
|
499f739e2b | ||
|
|
4e0e738823 | ||
|
|
24360bf306 | ||
|
|
6fad3c2bbd | ||
|
|
2cf20de7ec | ||
|
|
3d8d3bb8ce | ||
|
|
80bcbe4885 | ||
|
|
32e4782ed8 | ||
|
|
613a51bdbb | ||
|
|
1c6de43701 | ||
|
|
6771530025 | ||
|
|
5876f367bc | ||
|
|
e263af2dd9 | ||
|
|
3a59911a2b | ||
|
|
bbf31e99c3 | ||
|
|
9d5bd42f3e | ||
|
|
e721dae6da | ||
|
|
af3106b144 | ||
|
|
5b55103575 | ||
|
|
ee4ecf929f | ||
|
|
8336556a6f | ||
|
|
709aad1d3b | ||
|
|
fb7ab4937c | ||
|
|
5df1726d80 | ||
|
|
9fdb568843 | ||
|
|
8e76f56f89 | ||
|
|
05d3791577 | ||
|
|
d00dd7eb90 | ||
|
|
8d2e404017 | ||
|
|
95eb2af25e | ||
|
|
cbc00a501b | ||
|
|
480645d897 | ||
|
|
997c767c95 | ||
|
|
5a54e1dc9a | ||
|
|
49b1952566 | ||
|
|
e73edc2fce | ||
|
|
409652e874 | ||
|
|
1d3fb6431f | ||
|
|
76cfada60f | ||
|
|
ac45f80551 | ||
|
|
5ea85f086a | ||
|
|
e3f657746c |
3
.github/actions/setup/action.yml
vendored
3
.github/actions/setup/action.yml
vendored
@@ -12,13 +12,14 @@ inputs:
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install apt deps
|
||||
- name: Install apt deps & cleanup
|
||||
if: ${{ contains(inputs.dependencies, 'system') || contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get remove --purge man-db
|
||||
sudo apt-get update
|
||||
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext libkrb5-dev krb5-kdc krb5-user krb5-admin-server
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
- name: Install uv
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v5
|
||||
|
||||
2
.github/actions/test-results/action.yml
vendored
2
.github/actions/test-results/action.yml
vendored
@@ -20,7 +20,7 @@ runs:
|
||||
- name: PostgreSQL Logs
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ $ACTIONS_RUNNER_DEBUG == 'true' || $ACTIONS_STEP_DEBUG == 'true' ]]; then
|
||||
if [[ $RUNNER_DEBUG == '1' ]]; then
|
||||
docker stop setup-postgresql-1
|
||||
echo "::group::PostgreSQL Logs"
|
||||
docker logs setup-postgresql-1
|
||||
|
||||
2
.github/workflows/ci-main.yml
vendored
2
.github/workflows/ci-main.yml
vendored
@@ -84,7 +84,7 @@ jobs:
|
||||
# Current version family based on
|
||||
current_version_family=$(cat internal/constants/VERSION | grep -vE -- 'rc[0-9]+$' || true)
|
||||
if [[ -n $current_version_family ]]; then
|
||||
prev_stable=$current_version_family
|
||||
prev_stable="version/${current_version_family}"
|
||||
fi
|
||||
echo "::notice::Checking out ${prev_stable} as stable version..."
|
||||
git checkout ${prev_stable}
|
||||
|
||||
5
.github/workflows/release-publish.yml
vendored
5
.github/workflows/release-publish.yml
vendored
@@ -101,7 +101,7 @@ jobs:
|
||||
- name: make empty clients
|
||||
run: |
|
||||
mkdir -p ./gen-ts-api
|
||||
mkdir -p ./gen-go-api
|
||||
make gen-client-go
|
||||
- name: Docker Login Registry
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3
|
||||
with:
|
||||
@@ -160,6 +160,9 @@ jobs:
|
||||
run: |
|
||||
npm ci
|
||||
npm run build-proxy
|
||||
- name: Build API client
|
||||
run: |
|
||||
make gen-client-go
|
||||
- name: Build outpost
|
||||
run: |
|
||||
set -x
|
||||
|
||||
4
.github/workflows/release-tag.yml
vendored
4
.github/workflows/release-tag.yml
vendored
@@ -49,8 +49,12 @@ jobs:
|
||||
test:
|
||||
name: Pre-release test
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- check-inputs
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v5
|
||||
with:
|
||||
ref: "version-${{ needs.check-inputs.outputs.major_version }}"
|
||||
- run: make test-docker
|
||||
bump-authentik:
|
||||
name: Bump authentik version
|
||||
|
||||
@@ -40,7 +40,7 @@ packages/tsconfig @goauthentik/frontend
|
||||
# Web
|
||||
web/ @goauthentik/frontend
|
||||
# Locale
|
||||
locale/ @goauthentik/backend @goauthentik/frontend
|
||||
/locale/ @goauthentik/backend @goauthentik/frontend
|
||||
web/xliff/ @goauthentik/backend @goauthentik/frontend
|
||||
# Docs
|
||||
website/ @goauthentik/docs
|
||||
|
||||
2
Makefile
2
Makefile
@@ -327,6 +327,6 @@ ci-pending-migrations: ci--meta-debug
|
||||
uv run ak makemigrations --check
|
||||
|
||||
ci-test: ci--meta-debug
|
||||
uv run coverage run manage.py test --keepdb --randomly-seed ${CI_TEST_SEED} authentik
|
||||
uv run coverage run manage.py test --keepdb authentik
|
||||
uv run coverage report
|
||||
uv run coverage xml
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
|
||||
VERSION = "2025.12.0-rc1"
|
||||
VERSION = "2025.12.1"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class VersionSerializer(PassiveSerializer):
|
||||
|
||||
def get_version_latest(self, _) -> str:
|
||||
"""Get latest version from cache"""
|
||||
if get_current_tenant().schema_name == get_public_schema_name():
|
||||
if get_current_tenant().schema_name != get_public_schema_name():
|
||||
return authentik_version()
|
||||
version_in_cache = cache.get(VERSION_CACHE_KEY)
|
||||
if not version_in_cache: # pragma: no cover
|
||||
|
||||
@@ -62,10 +62,10 @@ class TestSanitizeFilePath(TestCase):
|
||||
"test@file.png", # @
|
||||
"test#file.png", # #
|
||||
"test$file.png", # $
|
||||
"test%file.png", # %
|
||||
"test%file.png", # % (but %(theme)s is allowed)
|
||||
"test&file.png", # &
|
||||
"test*file.png", # *
|
||||
"test(file).png", # parentheses
|
||||
"test(file).png", # parentheses (but %(theme)s is allowed)
|
||||
"test[file].png", # brackets
|
||||
"test{file}.png", # braces
|
||||
]
|
||||
@@ -108,3 +108,30 @@ class TestSanitizeFilePath(TestCase):
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
validate_file_name(path)
|
||||
|
||||
def test_sanitize_theme_variable_valid(self):
|
||||
"""Test sanitizing filename with %(theme)s variable"""
|
||||
# These should all be valid
|
||||
validate_file_name("logo-%(theme)s.png")
|
||||
validate_file_name("brand/logo-%(theme)s.svg")
|
||||
validate_file_name("images/icon-%(theme)s.png")
|
||||
validate_file_name("%(theme)s/logo.png")
|
||||
validate_file_name("brand/%(theme)s/logo.png")
|
||||
|
||||
def test_sanitize_theme_variable_multiple(self):
|
||||
"""Test sanitizing filename with multiple %(theme)s variables"""
|
||||
validate_file_name("%(theme)s/logo-%(theme)s.png")
|
||||
|
||||
def test_sanitize_theme_variable_invalid_format(self):
|
||||
"""Test that partial or malformed theme variables are rejected"""
|
||||
invalid_paths = [
|
||||
"test%(theme.png", # missing )s
|
||||
"test%theme)s.png", # missing (
|
||||
"test%(themes).png", # wrong variable name
|
||||
"test%(THEME)s.png", # wrong case
|
||||
"test%()s.png", # empty variable name
|
||||
]
|
||||
|
||||
for path in invalid_paths:
|
||||
with self.assertRaises(ValidationError):
|
||||
validate_file_name(path)
|
||||
|
||||
@@ -12,6 +12,10 @@ from authentik.admin.files.usage import FileUsage
|
||||
MAX_FILE_NAME_LENGTH = 1024
|
||||
MAX_PATH_COMPONENT_LENGTH = 255
|
||||
|
||||
# Theme variable placeholder that can be used in file paths
|
||||
# This allows for theme-specific files like logo-%(theme)s.png
|
||||
THEME_VARIABLE = "%(theme)s"
|
||||
|
||||
|
||||
def validate_file_name(name: str) -> None:
|
||||
if PassthroughBackend(FileUsage.MEDIA).supports_file(name) or StaticBackend(
|
||||
@@ -39,12 +43,17 @@ def validate_upload_file_name(
|
||||
if not name:
|
||||
raise ValidationError(_("File name cannot be empty"))
|
||||
|
||||
# Same regex is used in the frontend as well
|
||||
if not re.match(r"^[a-zA-Z0-9._/-]+$", name):
|
||||
# Allow %(theme)s placeholder for theme-specific files
|
||||
# We temporarily replace it for validation, then check the result
|
||||
name_for_validation = name.replace(THEME_VARIABLE, "theme")
|
||||
|
||||
# Same regex is used in the frontend as well (without %(theme)s handling there)
|
||||
if not re.match(r"^[a-zA-Z0-9._/-]+$", name_for_validation):
|
||||
raise ValidationError(
|
||||
_(
|
||||
"File name can only contain letters (a-z, A-Z), numbers (0-9), "
|
||||
"dots (.), hyphens (-), underscores (_), and forward slashes (/)"
|
||||
"dots (.), hyphens (-), underscores (_), forward slashes (/), "
|
||||
"and the special placeholder %(theme)s for theme-specific files"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,9 @@ class Pagination(pagination.PageNumberPagination):
|
||||
|
||||
def get_page_size(self, request):
|
||||
if self.page_size_query_param in request.query_params:
|
||||
return min(super().get_page_size(request), request.tenant.pagination_max_page_size)
|
||||
page_size = super().get_page_size(request)
|
||||
if page_size is not None:
|
||||
return min(super().get_page_size(request), request.tenant.pagination_max_page_size)
|
||||
return request.tenant.pagination_default_page_size
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
|
||||
@@ -31,6 +31,7 @@ class Capabilities(models.TextChoices):
|
||||
"""Define capabilities which influence which APIs can/should be used"""
|
||||
|
||||
CAN_SAVE_MEDIA = "can_save_media"
|
||||
CAN_SAVE_REPORTS = "can_save_reports"
|
||||
CAN_GEO_IP = "can_geo_ip"
|
||||
CAN_ASN = "can_asn"
|
||||
CAN_IMPERSONATE = "can_impersonate"
|
||||
@@ -70,6 +71,8 @@ class ConfigView(APIView):
|
||||
caps = []
|
||||
if get_file_manager(FileUsage.MEDIA).manageable:
|
||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||
if get_file_manager(FileUsage.REPORTS).manageable:
|
||||
caps.append(Capabilities.CAN_SAVE_REPORTS)
|
||||
for processor in get_context_processors():
|
||||
if cap := processor.capability():
|
||||
caps.append(cap)
|
||||
|
||||
@@ -8,45 +8,62 @@ metadata:
|
||||
- Application (icon)
|
||||
- Source (icon)
|
||||
- Flow (background)
|
||||
- Endpoint Enrollment token (key)
|
||||
entries:
|
||||
- model: authentik_core.token
|
||||
identifiers:
|
||||
identifier: "%(uid)s-token"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
user: "%(user)s"
|
||||
intent: api
|
||||
- model: authentik_core.application
|
||||
identifiers:
|
||||
slug: "%(uid)s-app"
|
||||
attrs:
|
||||
name: "%(uid)s-app"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_sources_oauth.oauthsource
|
||||
identifiers:
|
||||
slug: "%(uid)s-source"
|
||||
attrs:
|
||||
name: "%(uid)s-source"
|
||||
provider_type: azuread
|
||||
consumer_key: "%(uid)s"
|
||||
consumer_secret: "%(uid)s"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_flows.flow
|
||||
identifiers:
|
||||
slug: "%(uid)s-flow"
|
||||
attrs:
|
||||
name: "%(uid)s-flow"
|
||||
title: "%(uid)s-flow"
|
||||
designation: authentication
|
||||
background: https://goauthentik.io/img/icon.png
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
password: "%(uid)s"
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s-no-password"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
token:
|
||||
- model: authentik_core.token
|
||||
identifiers:
|
||||
identifier: "%(uid)s-token"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
user: "%(user)s"
|
||||
intent: api
|
||||
app:
|
||||
- model: authentik_core.application
|
||||
identifiers:
|
||||
slug: "%(uid)s-app"
|
||||
attrs:
|
||||
name: "%(uid)s-app"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
source:
|
||||
- model: authentik_sources_oauth.oauthsource
|
||||
identifiers:
|
||||
slug: "%(uid)s-source"
|
||||
attrs:
|
||||
name: "%(uid)s-source"
|
||||
provider_type: azuread
|
||||
consumer_key: "%(uid)s"
|
||||
consumer_secret: "%(uid)s"
|
||||
icon: https://goauthentik.io/img/icon.png
|
||||
flow:
|
||||
- model: authentik_flows.flow
|
||||
identifiers:
|
||||
slug: "%(uid)s-flow"
|
||||
attrs:
|
||||
name: "%(uid)s-flow"
|
||||
title: "%(uid)s-flow"
|
||||
designation: authentication
|
||||
background: https://goauthentik.io/img/icon.png
|
||||
user:
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
password: "%(uid)s"
|
||||
- model: authentik_core.user
|
||||
identifiers:
|
||||
username: "%(uid)s-no-password"
|
||||
attrs:
|
||||
name: "%(uid)s"
|
||||
endpoint:
|
||||
- model: authentik_endpoints_connectors_agent.agentconnector
|
||||
id: connector
|
||||
identifiers:
|
||||
name: "%(uid)s"
|
||||
- model: authentik_endpoints_connectors_agent.enrollmenttoken
|
||||
identifiers:
|
||||
name: "%(uid)s"
|
||||
attrs:
|
||||
key: "%(uid)s"
|
||||
connector: !KeyOf connector
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.test import TransactionTestCase
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.core.models import Token, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.endpoints.connectors.agent.models import EnrollmentToken
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
|
||||
@@ -29,12 +30,18 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase):
|
||||
|
||||
def test_user(self):
|
||||
"""Test user"""
|
||||
user: User = User.objects.filter(username=self.uid).first()
|
||||
user = User.objects.filter(username=self.uid).first()
|
||||
self.assertIsNotNone(user)
|
||||
self.assertTrue(user.check_password(self.uid))
|
||||
|
||||
def test_user_null(self):
|
||||
"""Test user"""
|
||||
user: User = User.objects.filter(username=f"{self.uid}-no-password").first()
|
||||
user = User.objects.filter(username=f"{self.uid}-no-password").first()
|
||||
self.assertIsNotNone(user)
|
||||
self.assertFalse(user.has_usable_password())
|
||||
|
||||
def test_enrollment_token(self):
|
||||
"""Test endpoint enrollment token"""
|
||||
token = EnrollmentToken.objects.filter(name=self.uid).first()
|
||||
self.assertIsNotNone(token)
|
||||
self.assertEqual(token.key, self.uid)
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
||||
instance.status,
|
||||
BlueprintInstanceStatus.UNKNOWN,
|
||||
)
|
||||
apply_blueprint(instance.pk)
|
||||
apply_blueprint.send(instance.pk).get_result(block=True)
|
||||
instance.refresh_from_db()
|
||||
self.assertEqual(instance.last_applied_hash, "")
|
||||
self.assertEqual(
|
||||
|
||||
@@ -37,14 +37,21 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||
return super().validate(attrs)
|
||||
|
||||
def create(self, validated_data: dict) -> MetaResult:
|
||||
from authentik.blueprints.v1.tasks import apply_blueprint
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
|
||||
if not self.blueprint_instance:
|
||||
LOGGER.info("Blueprint does not exist, but not required")
|
||||
return MetaResult()
|
||||
LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
|
||||
|
||||
apply_blueprint(self.blueprint_instance.pk)
|
||||
# Apply blueprint directly using Importer to avoid task context requirements
|
||||
# and prevent deadlocks when called from within another blueprint task
|
||||
blueprint_content = self.blueprint_instance.retrieve()
|
||||
importer = Importer.from_string(blueprint_content, self.blueprint_instance.context)
|
||||
valid, logs = importer.validate()
|
||||
[log.log() for log in logs]
|
||||
if valid:
|
||||
importer.apply()
|
||||
return MetaResult()
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTaskNotFound
|
||||
from dramatiq.actor import actor
|
||||
from dramatiq.middleware import Middleware
|
||||
from structlog.stdlib import get_logger
|
||||
@@ -40,7 +39,6 @@ from authentik.events.utils import sanitize_dict
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.tasks.apps import PRIORITY_HIGH
|
||||
from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@@ -191,10 +189,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
|
||||
|
||||
@actor(description=_("Apply single blueprint."))
|
||||
def apply_blueprint(instance_pk: UUID):
|
||||
try:
|
||||
self = CurrentTask.get_task()
|
||||
except CurrentTaskNotFound:
|
||||
self = Task()
|
||||
self = CurrentTask.get_task()
|
||||
self.set_uid(str(instance_pk))
|
||||
instance: BlueprintInstance | None = None
|
||||
try:
|
||||
|
||||
@@ -33,6 +33,16 @@ from authentik.endpoints.connectors.agent.auth import AgentAuth
|
||||
from authentik.rbac.api.roles import RoleSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
|
||||
"pk",
|
||||
"username",
|
||||
"name",
|
||||
"is_active",
|
||||
"last_login",
|
||||
"email",
|
||||
"attributes",
|
||||
]
|
||||
|
||||
|
||||
class PartialUserSerializer(ModelSerializer):
|
||||
"""Partial User Serializer, does not include child relations."""
|
||||
@@ -42,16 +52,7 @@ class PartialUserSerializer(ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = [
|
||||
"pk",
|
||||
"username",
|
||||
"name",
|
||||
"is_active",
|
||||
"last_login",
|
||||
"email",
|
||||
"attributes",
|
||||
"uid",
|
||||
]
|
||||
fields = PARTIAL_USER_SERIALIZER_MODEL_FIELDS + ["uid"]
|
||||
|
||||
|
||||
class RelatedGroupSerializer(ModelSerializer):
|
||||
@@ -84,6 +85,7 @@ class GroupSerializer(ModelSerializer):
|
||||
source="roles",
|
||||
required=False,
|
||||
)
|
||||
inherited_roles_obj = SerializerMethodField(allow_null=True)
|
||||
num_pk = IntegerField(read_only=True)
|
||||
|
||||
@property
|
||||
@@ -107,6 +109,13 @@ class GroupSerializer(ModelSerializer):
|
||||
return True
|
||||
return str(request.query_params.get("include_parents", "false")).lower() == "true"
|
||||
|
||||
@property
|
||||
def _should_include_inherited_roles(self) -> bool:
|
||||
request: Request = self.context.get("request", None)
|
||||
if not request:
|
||||
return True
|
||||
return str(request.query_params.get("include_inherited_roles", "false")).lower() == "true"
|
||||
|
||||
@extend_schema_field(PartialUserSerializer(many=True))
|
||||
def get_users_obj(self, instance: Group) -> list[PartialUserSerializer] | None:
|
||||
if not self._should_include_users:
|
||||
@@ -125,6 +134,15 @@ class GroupSerializer(ModelSerializer):
|
||||
return None
|
||||
return RelatedGroupSerializer(instance.parents, many=True).data
|
||||
|
||||
@extend_schema_field(RoleSerializer(many=True))
|
||||
def get_inherited_roles_obj(self, instance: Group) -> list | None:
|
||||
"""Return only inherited roles from ancestor groups (excludes direct roles)"""
|
||||
if not self._should_include_inherited_roles:
|
||||
return None
|
||||
direct_role_pks = instance.roles.values_list("pk", flat=True)
|
||||
inherited_roles = instance.all_roles().exclude(pk__in=direct_role_pks)
|
||||
return RoleSerializer(inherited_roles, many=True).data
|
||||
|
||||
def validate_is_superuser(self, superuser: bool):
|
||||
"""Ensure that the user creating this group has permissions to set the superuser flag"""
|
||||
request: Request = self.context.get("request", None)
|
||||
@@ -166,6 +184,7 @@ class GroupSerializer(ModelSerializer):
|
||||
"attributes",
|
||||
"roles",
|
||||
"roles_obj",
|
||||
"inherited_roles_obj",
|
||||
"children",
|
||||
"children_obj",
|
||||
]
|
||||
@@ -255,14 +274,21 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
return [
|
||||
StrField(Group, "name"),
|
||||
BoolField(Group, "is_superuser", nullable=True),
|
||||
JSONSearchField(Group, "attributes", suggest_nested=False),
|
||||
JSONSearchField(Group, "attributes"),
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
base_qs = Group.objects.all().prefetch_related("roles")
|
||||
|
||||
if self.serializer_class(context={"request": self.request})._should_include_users:
|
||||
base_qs = base_qs.prefetch_related("users")
|
||||
# Only fetch fields needed by PartialUserSerializer to reduce DB load and instantiation
|
||||
# time
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch(
|
||||
"users",
|
||||
queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
|
||||
)
|
||||
)
|
||||
else:
|
||||
base_qs = base_qs.prefetch_related(
|
||||
Prefetch("users", queryset=User.objects.all().only("id"))
|
||||
@@ -281,6 +307,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
OpenApiParameter("include_children", bool, default=False),
|
||||
OpenApiParameter("include_parents", bool, default=False),
|
||||
OpenApiParameter("include_inherited_roles", bool, default=False),
|
||||
]
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
@@ -291,6 +318,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
OpenApiParameter("include_children", bool, default=False),
|
||||
OpenApiParameter("include_parents", bool, default=False),
|
||||
OpenApiParameter("include_inherited_roles", bool, default=False),
|
||||
]
|
||||
)
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
|
||||
@@ -518,7 +518,7 @@ class UserViewSet(
|
||||
StrField(User, "path"),
|
||||
BoolField(User, "is_active", nullable=True),
|
||||
ChoiceSearchField(User, "type"),
|
||||
JSONSearchField(User, "attributes", suggest_nested=False),
|
||||
JSONSearchField(User, "attributes"),
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
|
||||
@@ -18,10 +18,9 @@ def migrate_object_permissions(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||
RoleModelPermission = apps.get_model("guardian", "RoleModelPermission")
|
||||
|
||||
def get_role_for_user_id(user_id: int) -> Role:
|
||||
name = f"ak-managed-role--user-{user_id}"
|
||||
name = f"ak-migrated-role--user-{user_id}"
|
||||
role, created = Role.objects.using(db_alias).get_or_create(
|
||||
name=name,
|
||||
managed=name,
|
||||
)
|
||||
if created:
|
||||
role.users.add(user_id)
|
||||
@@ -32,11 +31,10 @@ def migrate_object_permissions(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||
if not role:
|
||||
# Every django group should already have a role, so this should never happen.
|
||||
# But let's be nice.
|
||||
name = f"ak-managed-role--group-{group_id}"
|
||||
name = f"ak-migrated-role--group-{group_id}"
|
||||
role, created = Role.objects.using(db_alias).get_or_create(
|
||||
group_id=group_id,
|
||||
name=name,
|
||||
managed=name,
|
||||
)
|
||||
if created:
|
||||
role.group_id = group_id
|
||||
|
||||
@@ -66,9 +66,12 @@ class SessionStore(SessionBase):
|
||||
def decode(self, session_data):
|
||||
try:
|
||||
return pickle.loads(session_data) # nosec
|
||||
except pickle.PickleError:
|
||||
# ValueError, unpickling exceptions. If any of these happen, just return an empty
|
||||
# dictionary (an empty session)
|
||||
except (pickle.PickleError, AttributeError, TypeError):
|
||||
# PickleError, ValueError - unpickling exceptions
|
||||
# AttributeError - can happen when Django model fields (e.g., FileField) are unpickled
|
||||
# and their descriptors fail to initialize (e.g., missing storage)
|
||||
# TypeError - can happen with incompatible pickled objects
|
||||
# If any of these happen, just return an empty dictionary (an empty session)
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
@@ -35,8 +35,13 @@ def clean_expired_models():
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
clear_expired_cache()
|
||||
Message.delete_expired()
|
||||
GroupChannel.delete_expired()
|
||||
for cls in [Message, GroupChannel]:
|
||||
objects = cls.objects.all().filter(expires__lt=now())
|
||||
amount = objects.count()
|
||||
for obj in chunked_queryset(objects):
|
||||
obj.delete()
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
|
||||
|
||||
@actor(description=_("Remove temporary users created by SAML Sources."))
|
||||
|
||||
@@ -10,15 +10,23 @@
|
||||
{% elif ui_theme == "light" %}
|
||||
<meta name="color-scheme" content="light" />
|
||||
<meta name="theme-color" content="#ffffff">
|
||||
{% else %}
|
||||
{% else %}
|
||||
<script data-id="theme-script">
|
||||
"use strict";
|
||||
|
||||
(function () {
|
||||
try {
|
||||
/* Ignore older theme names */
|
||||
let locallyStoredTheme = window.localStorage?.getItem("theme") || null;
|
||||
if (typeof locallyStoredTheme === "string") {
|
||||
locallyStoredTheme = locallyStoredTheme.trim();
|
||||
}
|
||||
if (!(["auto", "light", "dark"].includes(locallyStoredTheme))) {
|
||||
locallyStoredTheme = null;
|
||||
}
|
||||
|
||||
const initialThemeChoice =
|
||||
new URLSearchParams(window.location.search).get("theme") ||
|
||||
window.localStorage?.getItem("theme");
|
||||
new URLSearchParams(window.location.search).get("theme") || locallyStoredTheme;
|
||||
|
||||
const themeChoice =
|
||||
initialThemeChoice || document.documentElement.dataset.themeChoice || "auto";
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.db import models
|
||||
from rest_framework.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
@@ -14,6 +15,12 @@ from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.views.jwks import JWKSView
|
||||
|
||||
try:
|
||||
from authentik.enterprise.models import LicenseUsageStatus
|
||||
except ImportError:
|
||||
|
||||
class LicenseUsageStatus(models.TextChoices): ...
|
||||
|
||||
|
||||
class AgentConfigSerializer(PassiveSerializer):
|
||||
|
||||
@@ -29,6 +36,7 @@ class AgentConfigSerializer(PassiveSerializer):
|
||||
auth_terminate_session_on_expiry = BooleanField()
|
||||
|
||||
system_config = SerializerMethodField()
|
||||
license_status = SerializerMethodField(required=False, allow_null=True)
|
||||
|
||||
def get_device_id(self, instance: AgentConnector) -> str:
|
||||
device: Device = self.context["device"]
|
||||
@@ -54,6 +62,14 @@ class AgentConfigSerializer(PassiveSerializer):
|
||||
def get_system_config(self, instance: AgentConnector) -> ConfigSerializer:
|
||||
return ConfigView.get_config(self.context["request"]).data
|
||||
|
||||
def get_license_status(self, instance: AgentConnector) -> "LicenseUsageStatus":
|
||||
try:
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.cached_summary().status
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
class EnrollSerializer(PassiveSerializer):
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import cast
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied, ValidationError
|
||||
from rest_framework.fields import ChoiceField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@@ -22,6 +24,9 @@ from authentik.endpoints.connectors.agent.api.agent import (
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
AgentAuth,
|
||||
AgentEnrollmentAuth,
|
||||
DeviceAuthFedAuthentication,
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.controller import MDMConfigResponseSerializer
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
@@ -32,7 +37,10 @@ from authentik.endpoints.connectors.agent.models import (
|
||||
)
|
||||
from authentik.endpoints.facts import DeviceFacts, OSFamily
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
from authentik.lib.utils.reflection import ConditionalInheritance
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
|
||||
|
||||
class AgentConnectorSerializer(ConnectorSerializer):
|
||||
@@ -163,3 +171,43 @@ class AgentConnectorViewSet(
|
||||
connection: AgentDeviceConnection = token.device
|
||||
connection.create_snapshot(data.validated_data)
|
||||
return Response(status=204)
|
||||
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
parameters=[OpenApiParameter("device", OpenApiTypes.STR, location="query", required=True)],
|
||||
responses={
|
||||
200: AgentTokenResponseSerializer(),
|
||||
404: OpenApiResponse(description="Device not found"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
filter_backends=[],
|
||||
permission_classes=[IsAuthenticated],
|
||||
authentication_classes=[DeviceAuthFedAuthentication],
|
||||
)
|
||||
def auth_fed(self, request: Request) -> Response:
|
||||
federated_token, device, connector = request.auth
|
||||
|
||||
policy_result = check_device_policies(device, federated_token.user, request._request)
|
||||
if not policy_result.passing:
|
||||
raise ValidationError(
|
||||
{"policy_result": "Policy denied access", "policy_messages": policy_result.messages}
|
||||
)
|
||||
|
||||
token, exp = agent_auth_issue_token(device, connector, federated_token.user)
|
||||
rel_exp = int((exp - now()).total_seconds())
|
||||
Event.new(
|
||||
EventAction.LOGIN,
|
||||
**{
|
||||
PLAN_CONTEXT_METHOD: "jwt",
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"jwt": federated_token,
|
||||
"provider": federated_token.provider,
|
||||
},
|
||||
PLAN_CONTEXT_DEVICE: device,
|
||||
},
|
||||
).from_http(request, user=federated_token.user)
|
||||
return Response({"token": token, "expires_in": rel_exp})
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.tokens import TokenViewSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
@@ -19,6 +21,11 @@ class EnrollmentTokenSerializer(ModelSerializer):
|
||||
source="device_group", read_only=True, required=False
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["key"] = CharField(required=False)
|
||||
|
||||
class Meta:
|
||||
model = EnrollmentToken
|
||||
fields = [
|
||||
|
||||
@@ -1,13 +1,28 @@
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from jwt import PyJWTError, decode, encode
|
||||
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import IPCUser, validate_auth
|
||||
from authentik.core.middleware import CTX_AUTH_VIA
|
||||
from authentik.core.models import User
|
||||
from authentik.endpoints.connectors.agent.models import DeviceToken, EnrollmentToken
|
||||
from authentik.crypto.apps import MANAGED_KEY
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceToken, EnrollmentToken
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLATFORM_ISSUER = "goauthentik.io/platform"
|
||||
|
||||
|
||||
class DeviceUser(IPCUser):
|
||||
@@ -40,3 +55,96 @@ class AgentAuth(BaseAuthentication):
|
||||
raise PermissionDenied()
|
||||
CTX_AUTH_VIA.set("endpoint_token")
|
||||
return (DeviceUser(), device_token)
|
||||
|
||||
|
||||
def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
|
||||
kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
|
||||
if not kp:
|
||||
return None, None
|
||||
exp = now() + timedelta_from_string(connector.auth_session_duration)
|
||||
token = encode(
|
||||
{
|
||||
"iss": PLATFORM_ISSUER,
|
||||
"aud": str(device.pk),
|
||||
"iat": int(now().timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"preferred_username": user.username,
|
||||
**kwargs,
|
||||
},
|
||||
kp.private_key,
|
||||
headers={
|
||||
"kid": kp.kid,
|
||||
},
|
||||
algorithm=JWTAlgorithms.from_private_key(kp.private_key),
|
||||
)
|
||||
return token, exp
|
||||
|
||||
|
||||
class DeviceAuthFedAuthentication(BaseAuthentication):
|
||||
|
||||
def authenticate(self, request):
|
||||
raw_token = validate_auth(get_authorization_header(request))
|
||||
if not raw_token:
|
||||
LOGGER.warning("Missing token")
|
||||
return None
|
||||
device = Device.filter_not_expired(name=request.query_params.get("device")).first()
|
||||
if not device:
|
||||
LOGGER.warning("Couldn't find device")
|
||||
return None
|
||||
connectors_for_device = AgentConnector.objects.filter(device__in=[device])
|
||||
connector = connectors_for_device.first()
|
||||
providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
|
||||
federated_token = AccessToken.objects.filter(
|
||||
token=raw_token, provider__in=providers
|
||||
).first()
|
||||
if not federated_token:
|
||||
LOGGER.warning("Couldn't lookup provider")
|
||||
return None
|
||||
_key, _alg = federated_token.provider.jwt_key
|
||||
try:
|
||||
decode(
|
||||
raw_token,
|
||||
_key.public_key(),
|
||||
algorithms=[_alg],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
},
|
||||
)
|
||||
LOGGER.info(
|
||||
"successfully verified JWT with provider", provider=federated_token.provider.name
|
||||
)
|
||||
return (federated_token.user, (federated_token, device, connector))
|
||||
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||
LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
|
||||
return None
|
||||
|
||||
|
||||
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
|
||||
"""Auth schema"""
|
||||
|
||||
target_class = DeviceAuthFedAuthentication
|
||||
name = "device_federation"
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
"""Auth schema"""
|
||||
return {"type": "http", "scheme": "bearer"}
|
||||
|
||||
|
||||
def check_device_policies(device: Device, user: User, request: HttpRequest):
|
||||
"""Check policies bound to device group and device"""
|
||||
if device.access_group:
|
||||
result = check_pbm_policies(device.access_group, user, request)
|
||||
if result.passing:
|
||||
return result
|
||||
return check_pbm_policies(device, user, request)
|
||||
|
||||
|
||||
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
|
||||
policy_engine = PolicyEngine(pbm, user, request)
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.mode = pbm.policy_engine_mode
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=result, pbm=pbm.pk)
|
||||
return result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from hashlib import sha256
|
||||
from hmac import compare_digest
|
||||
|
||||
from django.http import HttpResponse
|
||||
@@ -8,7 +9,7 @@ from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import DeviceToken
|
||||
from authentik.endpoints.connectors.agent.models import DeviceAuthenticationToken, DeviceToken
|
||||
from authentik.endpoints.models import Device, EndpointStage, StageMode
|
||||
from authentik.flows.challenge import (
|
||||
Challenge,
|
||||
@@ -20,6 +21,7 @@ from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.models import JWTAlgorithms
|
||||
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN = "goauthentik.io/endpoints/device_auth_token" # nosec
|
||||
PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE = "goauthentik.io/endpoints/connectors/agent/challenge"
|
||||
QS_CHALLENGE = "challenge"
|
||||
QS_CHALLENGE_RESPONSE = "response"
|
||||
@@ -85,12 +87,36 @@ class AuthenticatorEndpointStageView(ChallengeStageView):
|
||||
response_class = EndpointAgentChallengeResponse
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
# Check if we're in a device interactive auth flow, in which case we use that
|
||||
# to prove which device is being used
|
||||
if response := self.check_device_ia():
|
||||
return response
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
keypair = CertificateKeyPair.objects.filter(pk=stage.connector.challenge_key_id).first()
|
||||
if not keypair:
|
||||
return self.executor.stage_ok()
|
||||
return super().get(request, *args, **kwargs)
|
||||
|
||||
def check_device_ia(self):
|
||||
"""Check if we're in a device interactive authentication flow, and if so,
|
||||
there won't be a browser extension to talk to. However we can authenticate
|
||||
on the DTH header"""
|
||||
if PLAN_CONTEXT_DEVICE_AUTH_TOKEN not in self.executor.plan.context:
|
||||
return None
|
||||
auth_token: DeviceAuthenticationToken = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN
|
||||
)
|
||||
device_token_hash = self.request.headers.get("X-Authentik-Platform-Auth-DTH")
|
||||
if not device_token_hash:
|
||||
return None
|
||||
if not compare_digest(
|
||||
device_token_hash, sha256(auth_token.device_token.key.encode()).hexdigest()
|
||||
):
|
||||
return self.executor.stage_invalid("Invalid device token")
|
||||
self.logger.debug("Setting device based on DTH header")
|
||||
self.executor.plan.context[PLAN_CONTEXT_DEVICE] = auth_token.device
|
||||
return self.executor.stage_ok()
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||
stage: EndpointStage = self.executor.current_stage
|
||||
keypair = CertificateKeyPair.objects.get(pk=stage.connector.challenge_key_id)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from hashlib import sha256
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
@@ -7,10 +8,14 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
AgentConnector,
|
||||
AgentDeviceConnection,
|
||||
DeviceAuthenticationToken,
|
||||
DeviceToken,
|
||||
EnrollmentToken,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.stage import PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE
|
||||
from authentik.endpoints.connectors.agent.stage import (
|
||||
PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE,
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN,
|
||||
)
|
||||
from authentik.endpoints.models import Device, EndpointStage, StageMode
|
||||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
@@ -35,6 +40,11 @@ class TestEndpointStage(FlowTestCase):
|
||||
device=self.connection,
|
||||
key=generate_id(),
|
||||
)
|
||||
self.device_auth_token = DeviceAuthenticationToken.objects.create(
|
||||
device=self.device,
|
||||
device_token=self.device_token,
|
||||
connector=self.connector,
|
||||
)
|
||||
|
||||
def test_endpoint_stage(self):
|
||||
flow = create_test_flow()
|
||||
@@ -194,3 +204,31 @@ class TestEndpointStage(FlowTestCase):
|
||||
"response": [{"string": "Invalid challenge response", "code": "invalid"}]
|
||||
},
|
||||
)
|
||||
|
||||
def test_endpoint_stage_ia_dth(self):
|
||||
"""Test with DTH"""
|
||||
flow = create_test_flow()
|
||||
stage = EndpointStage.objects.create(connector=self.connector)
|
||||
FlowStageBinding.objects.create(stage=stage, target=flow, order=0)
|
||||
|
||||
# Send an "invalid" request first, to populate the flow plan
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
plan = self.get_flow_plan()
|
||||
plan.context[PLAN_CONTEXT_DEVICE_AUTH_TOKEN] = DeviceAuthenticationToken.objects.get(
|
||||
pk=self.device_auth_token.pk
|
||||
)
|
||||
self.set_flow_plan(plan)
|
||||
|
||||
with self.assertFlowFinishes() as plan:
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
HTTP_X_AUTHENTIK_PLATFORM_AUTH_DTH=sha256(
|
||||
self.device_token.key.encode()
|
||||
).hexdigest(),
|
||||
)
|
||||
self.assertStageRedirects(res, reverse("authentik_core:root-redirect"))
|
||||
plan = plan()
|
||||
self.assertNotIn(PLAN_CONTEXT_AGENT_ENDPOINT_CHALLENGE, plan.context)
|
||||
self.assertEqual(plan.context[PLAN_CONTEXT_DEVICE], self.device)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Enterprise API Views"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from functools import wraps
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -35,6 +37,18 @@ class EnterpriseRequiredMixin:
|
||||
return super().validate(attrs)
|
||||
|
||||
|
||||
def enterprise_action(func: Callable):
|
||||
"""Check permissions for a single custom action"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Response:
|
||||
if not LicenseKey.cached_summary().status.is_valid:
|
||||
raise ValidationError(_("Enterprise is required to use this endpoint."))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class LicenseSerializer(ModelSerializer):
|
||||
"""License Serializer"""
|
||||
|
||||
|
||||
@@ -1,31 +1,20 @@
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.endpoints.connectors.agent.api.agent import (
|
||||
AgentAuthenticationResponse,
|
||||
AgentTokenResponseSerializer,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.auth import AgentAuth
|
||||
from authentik.endpoints.connectors.agent.models import (
|
||||
DeviceAuthenticationToken,
|
||||
DeviceToken,
|
||||
)
|
||||
from authentik.enterprise.endpoints.connectors.agent.auth import (
|
||||
DeviceAuthFedAuthentication,
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_DEVICE
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
from authentik.enterprise.api import enterprise_action
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@@ -37,6 +26,7 @@ class AgentConnectorViewSetMixin:
|
||||
responses=AgentAuthenticationResponse(),
|
||||
)
|
||||
@action(methods=["POST"], detail=False, authentication_classes=[AgentAuth])
|
||||
@enterprise_action
|
||||
def auth_ia(self, request: Request) -> Response:
|
||||
token: DeviceToken = request.auth
|
||||
auth_token = DeviceAuthenticationToken.objects.create(
|
||||
@@ -54,43 +44,3 @@ class AgentConnectorViewSetMixin:
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
parameters=[OpenApiParameter("device", OpenApiTypes.STR, location="query", required=True)],
|
||||
responses={
|
||||
200: AgentTokenResponseSerializer(),
|
||||
404: OpenApiResponse(description="Device not found"),
|
||||
},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
filter_backends=[],
|
||||
permission_classes=[IsAuthenticated],
|
||||
authentication_classes=[DeviceAuthFedAuthentication],
|
||||
)
|
||||
def auth_fed(self, request: Request) -> Response:
|
||||
federated_token, device, connector = request.auth
|
||||
|
||||
policy_result = check_device_policies(device, federated_token.user, request._request)
|
||||
if not policy_result.passing:
|
||||
raise ValidationError(
|
||||
{"policy_result": "Policy denied access", "policy_messages": policy_result.messages}
|
||||
)
|
||||
|
||||
token, exp = agent_auth_issue_token(device, connector, federated_token.user)
|
||||
rel_exp = int((exp - now()).total_seconds())
|
||||
Event.new(
|
||||
EventAction.LOGIN,
|
||||
**{
|
||||
PLAN_CONTEXT_METHOD: "jwt",
|
||||
PLAN_CONTEXT_METHOD_ARGS: {
|
||||
"jwt": federated_token,
|
||||
"provider": federated_token.provider,
|
||||
},
|
||||
PLAN_CONTEXT_DEVICE: device,
|
||||
},
|
||||
).from_http(request, user=federated_token.user)
|
||||
return Response({"token": token, "expires_in": rel_exp})
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from jwt import PyJWTError, decode, encode
|
||||
from rest_framework.authentication import BaseAuthentication
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import get_authorization_header, validate_auth
|
||||
from authentik.core.models import User
|
||||
from authentik.crypto.apps import MANAGED_KEY
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLATFORM_ISSUER = "goauthentik.io/platform"
|
||||
|
||||
|
||||
def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
|
||||
kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
|
||||
if not kp:
|
||||
return None, None
|
||||
exp = now() + timedelta_from_string(connector.auth_session_duration)
|
||||
token = encode(
|
||||
{
|
||||
"iss": PLATFORM_ISSUER,
|
||||
"aud": str(device.pk),
|
||||
"iat": int(now().timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"preferred_username": user.username,
|
||||
**kwargs,
|
||||
},
|
||||
kp.private_key,
|
||||
headers={
|
||||
"kid": kp.kid,
|
||||
},
|
||||
algorithm=JWTAlgorithms.from_private_key(kp.private_key),
|
||||
)
|
||||
return token, exp
|
||||
|
||||
|
||||
class DeviceAuthFedAuthentication(BaseAuthentication):
|
||||
|
||||
def authenticate(self, request):
|
||||
raw_token = validate_auth(get_authorization_header(request))
|
||||
if not raw_token:
|
||||
LOGGER.warning("Missing token")
|
||||
return None
|
||||
device = Device.filter_not_expired(name=request.query_params.get("device")).first()
|
||||
if not device:
|
||||
LOGGER.warning("Couldn't find device")
|
||||
return None
|
||||
connectors_for_device = AgentConnector.objects.filter(device__in=[device])
|
||||
connector = connectors_for_device.first()
|
||||
providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
|
||||
federated_token = AccessToken.objects.filter(
|
||||
token=raw_token, provider__in=providers
|
||||
).first()
|
||||
if not federated_token:
|
||||
LOGGER.warning("Couldn't lookup provider")
|
||||
return None
|
||||
_key, _alg = federated_token.provider.jwt_key
|
||||
try:
|
||||
decode(
|
||||
raw_token,
|
||||
_key.public_key(),
|
||||
algorithms=[_alg],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
},
|
||||
)
|
||||
LOGGER.info(
|
||||
"successfully verified JWT with provider", provider=federated_token.provider.name
|
||||
)
|
||||
return (federated_token.user, (federated_token, device, connector))
|
||||
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||
LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
|
||||
return None
|
||||
|
||||
|
||||
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
|
||||
"""Auth schema"""
|
||||
|
||||
target_class = DeviceAuthFedAuthentication
|
||||
name = "device_federation"
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
"""Auth schema"""
|
||||
return {"type": "http", "scheme": "bearer"}
|
||||
|
||||
|
||||
def check_device_policies(device: Device, user: User, request: HttpRequest):
|
||||
"""Check policies bound to device group and device"""
|
||||
if device.access_group:
|
||||
result = check_pbm_policies(device.access_group, user, request)
|
||||
if result.passing:
|
||||
return result
|
||||
return check_pbm_policies(device, user, request)
|
||||
|
||||
|
||||
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
|
||||
policy_engine = PolicyEngine(pbm, user, request)
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.mode = pbm.policy_engine_mode
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=result, pbm=pbm.pk)
|
||||
return result
|
||||
@@ -63,8 +63,21 @@ class TestConnectorAuthIA(FlowTestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@reconcile_app("authentik_crypto")
|
||||
def test_auth_ia_fulfill(self):
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:agentconnector-auth-ia"),
|
||||
|
||||
@@ -3,12 +3,13 @@ from hmac import compare_digest
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, QueryDict
|
||||
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceAuthenticationToken
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.enterprise.endpoints.connectors.agent.auth import (
|
||||
from authentik.endpoints.connectors.agent.auth import (
|
||||
agent_auth_issue_token,
|
||||
check_device_policies,
|
||||
)
|
||||
from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceAuthenticationToken
|
||||
from authentik.endpoints.connectors.agent.stage import PLAN_CONTEXT_DEVICE_AUTH_TOKEN
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
@@ -16,8 +17,6 @@ from authentik.flows.planner import PLAN_CONTEXT_DEVICE, FlowPlanner
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme
|
||||
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN = "goauthentik.io/endpoints/device_auth_token" # nosec
|
||||
|
||||
QS_AGENT_IA_TOKEN = "ak-auth-ia-token" # nosec
|
||||
|
||||
|
||||
|
||||
@@ -4,37 +4,35 @@ from django.urls import reverse
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import mixins
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.fields import CharField, SerializerMethodField
|
||||
from rest_framework.permissions import BasePermission
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import PartialUserSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.reports.models import DataExport
|
||||
from authentik.enterprise.reports.tasks import generate_export
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
|
||||
|
||||
class RequestedBySerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ("pk", "username")
|
||||
|
||||
|
||||
class ContentTypeSerializer(ModelSerializer):
|
||||
app_label = CharField(read_only=True)
|
||||
model = CharField(read_only=True)
|
||||
verbose_name_plural = SerializerMethodField()
|
||||
|
||||
def get_verbose_name_plural(self, ct: ContentType) -> str:
|
||||
return ct.model_class()._meta.verbose_name_plural
|
||||
|
||||
class Meta:
|
||||
model = ContentType
|
||||
fields = ("id", "app_label", "model")
|
||||
fields = ("id", "app_label", "model", "verbose_name_plural")
|
||||
|
||||
|
||||
class DataExportSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
requested_by = RequestedBySerializer(read_only=True)
|
||||
requested_by = PartialUserSerializer(read_only=True)
|
||||
content_type = ContentTypeSerializer(read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
@@ -7,6 +7,7 @@ from django.db import connection
|
||||
from django.db.models import Model, Q
|
||||
from djangoql.compat import text_type
|
||||
from djangoql.schema import StrField
|
||||
from djangoql.serializers import DjangoQLSchemaSerializer
|
||||
|
||||
|
||||
class JSONSearchField(StrField):
|
||||
@@ -14,10 +15,18 @@ class JSONSearchField(StrField):
|
||||
|
||||
model: Model
|
||||
|
||||
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
name=None,
|
||||
nullable=None,
|
||||
suggest_nested=False,
|
||||
fixed_structure: OrderedDict | None = None,
|
||||
):
|
||||
# Set this in the constructor to not clobber the type variable
|
||||
self.type = "relation"
|
||||
self.suggest_nested = suggest_nested
|
||||
self.fixed_structure = fixed_structure
|
||||
super().__init__(model, name, nullable)
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
@@ -57,11 +66,23 @@ class JSONSearchField(StrField):
|
||||
)
|
||||
return (x[0] for x in cursor.fetchall())
|
||||
|
||||
def get_nested_options(self) -> OrderedDict:
|
||||
def get_fixed_structure(self, serializer: DjangoQLSchemaSerializer) -> OrderedDict:
|
||||
new_dict = OrderedDict()
|
||||
if not self.fixed_structure:
|
||||
return new_dict
|
||||
new_dict.setdefault(self.relation(), {})
|
||||
for key, value in self.fixed_structure.items():
|
||||
new_dict[self.relation()][key] = serializer.serialize_field(value)
|
||||
if isinstance(value, JSONSearchField):
|
||||
new_dict.update(value.get_nested_options(serializer))
|
||||
return new_dict
|
||||
|
||||
def get_nested_options(self, serializer: DjangoQLSchemaSerializer) -> OrderedDict:
|
||||
"""Get keys of all nested objects to show autocomplete"""
|
||||
if not self.suggest_nested:
|
||||
if self.fixed_structure:
|
||||
return self.get_fixed_structure(serializer)
|
||||
return OrderedDict()
|
||||
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
|
||||
if not parent_parts:
|
||||
@@ -87,7 +108,7 @@ class JSONSearchField(StrField):
|
||||
relation_structure = defaultdict(dict)
|
||||
|
||||
for relations in self.json_field_keys():
|
||||
result = recursive_function([base_model_name] + relations)
|
||||
result = recursive_function([self.relation()] + relations)
|
||||
for relation_key, value in result.items():
|
||||
for sub_relation_key, sub_value in value.items():
|
||||
if not relation_structure[relation_key].get(sub_relation_key, None):
|
||||
|
||||
@@ -12,7 +12,7 @@ class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
|
||||
for _, field in fields.items():
|
||||
if not isinstance(field, JSONSearchField):
|
||||
continue
|
||||
serialization["models"].update(field.get_nested_options())
|
||||
serialization["models"].update(field.get_nested_options(self))
|
||||
return serialization
|
||||
|
||||
def serialize_field(self, field):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Events API Views"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
|
||||
import django_filters
|
||||
@@ -136,7 +137,7 @@ class EventViewSet(
|
||||
filterset_class = EventsFilter
|
||||
|
||||
def get_ql_fields(self):
|
||||
from djangoql.schema import DateTimeField, StrField
|
||||
from djangoql.schema import DateTimeField, IntField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
|
||||
|
||||
@@ -145,9 +146,42 @@ class EventViewSet(
|
||||
StrField(Event, "event_uuid"),
|
||||
StrField(Event, "app", suggest_options=True),
|
||||
StrField(Event, "client_ip"),
|
||||
JSONSearchField(Event, "user", suggest_nested=False),
|
||||
JSONSearchField(Event, "brand", suggest_nested=False),
|
||||
JSONSearchField(Event, "context", suggest_nested=False),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"user",
|
||||
fixed_structure=OrderedDict(
|
||||
pk=IntField(),
|
||||
username=StrField(),
|
||||
email=StrField(),
|
||||
),
|
||||
),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"brand",
|
||||
fixed_structure=OrderedDict(
|
||||
pk=StrField(),
|
||||
app=StrField(),
|
||||
name=StrField(),
|
||||
model_name=StrField(),
|
||||
),
|
||||
),
|
||||
JSONSearchField(
|
||||
Event,
|
||||
"context",
|
||||
fixed_structure=OrderedDict(
|
||||
http_request=JSONSearchField(
|
||||
Event,
|
||||
"context_http_request",
|
||||
fixed_structure=OrderedDict(
|
||||
args=JSONSearchField(Event, "context_http_request_args"),
|
||||
path=StrField(),
|
||||
method=StrField(),
|
||||
request_id=StrField(),
|
||||
user_agent=StrField(),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
DateTimeField(Event, "created", suggest_options=True),
|
||||
]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.fields import CharField, ChoiceField, DateTimeField, DictField
|
||||
from structlog import configure, get_config
|
||||
from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter
|
||||
from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter, get_logger
|
||||
from structlog.testing import LogCapture
|
||||
from structlog.types import EventDict
|
||||
|
||||
@@ -36,6 +36,9 @@ class LogEvent:
|
||||
event, log_level, item.pop("logger"), timestamp, attributes=sanitize_dict(item)
|
||||
)
|
||||
|
||||
def log(self):
|
||||
get_logger(self.logger).log(NAME_TO_LEVEL[self.log_level], self.event, **self.attributes)
|
||||
|
||||
|
||||
class LogEventSerializer(PassiveSerializer):
|
||||
"""Single log message with all context logged."""
|
||||
|
||||
@@ -8,6 +8,8 @@ from inspect import currentframe
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.apps import apps
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
@@ -41,6 +43,7 @@ from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.root.ws.consumer import build_user_group
|
||||
from authentik.stages.email.models import EmailTemplates
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tasks.models import TasksModel
|
||||
@@ -361,6 +364,15 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
notification=notification,
|
||||
)
|
||||
notification.save()
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
build_user_group(notification.user),
|
||||
{
|
||||
"type": "event.notification",
|
||||
"id": str(notification.pk),
|
||||
"data": notification.serializer(notification).data,
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
def send_webhook(self, notification: "Notification") -> list[str]:
|
||||
|
||||
@@ -48,6 +48,14 @@ class FlowTestCase(APITestCase):
|
||||
self.assertEqual(raw_response[key], expected)
|
||||
return raw_response
|
||||
|
||||
def get_flow_plan(self) -> FlowPlan | None:
|
||||
return self.client.session.get(SESSION_KEY_PLAN)
|
||||
|
||||
def set_flow_plan(self, plan: FlowPlan):
|
||||
session = self.client.session
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
|
||||
def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]:
|
||||
"""Wrapper around assertStageResponse that checks for a redirect"""
|
||||
return self.assertStageResponse(response, component="xak-flow-redirect", to=to)
|
||||
|
||||
@@ -84,7 +84,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
raise NotImplementedError
|
||||
|
||||
def sync_dispatch(self) -> None:
|
||||
for schedule in self.schedules:
|
||||
for schedule in self.schedules.all():
|
||||
schedule.send()
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""authentik database utilities"""
|
||||
|
||||
import gc
|
||||
from collections.abc import Generator
|
||||
|
||||
from django.db import reset_queries
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models import Model, QuerySet
|
||||
|
||||
|
||||
def chunked_queryset(queryset: QuerySet, chunk_size: int = 1_000):
|
||||
def chunked_queryset[T: Model](queryset: QuerySet[T], chunk_size: int = 1_000) -> Generator[T]:
|
||||
if not queryset.exists():
|
||||
return []
|
||||
|
||||
def get_chunks(qs: QuerySet):
|
||||
def get_chunks(qs: QuerySet) -> Generator[QuerySet[T]]:
|
||||
qs = qs.order_by("pk")
|
||||
pks = qs.values_list("pk", flat=True)
|
||||
start_pk = pks[0]
|
||||
|
||||
@@ -86,7 +86,7 @@ class OutpostConfig:
|
||||
class OutpostModel(Model):
|
||||
"""Base model for providers that need more objects than just themselves"""
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
"""Return a list of all required objects"""
|
||||
return [self]
|
||||
|
||||
@@ -332,41 +332,35 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
"""Create per-object and global permissions for outpost service-account"""
|
||||
# To ensure the user only has the correct permissions, we delete all of them and re-add
|
||||
# the ones the user needs
|
||||
with transaction.atomic():
|
||||
user.remove_all_perms_from_managed_role()
|
||||
for model_or_perm in self.get_required_objects():
|
||||
if isinstance(model_or_perm, models.Model):
|
||||
model_or_perm: models.Model
|
||||
code_name = (
|
||||
f"{model_or_perm._meta.app_label}.view_{model_or_perm._meta.model_name}"
|
||||
)
|
||||
try:
|
||||
user.assign_perms_to_managed_role(code_name, model_or_perm)
|
||||
except (Permission.DoesNotExist, AttributeError) as exc:
|
||||
LOGGER.warning(
|
||||
"permission doesn't exist",
|
||||
code_name=code_name,
|
||||
user=user,
|
||||
model=model_or_perm,
|
||||
try:
|
||||
with transaction.atomic():
|
||||
user.remove_all_perms_from_managed_role()
|
||||
for model_or_perm in self.get_required_objects():
|
||||
if isinstance(model_or_perm, models.Model):
|
||||
code_name = (
|
||||
f"{model_or_perm._meta.app_label}.view_{model_or_perm._meta.model_name}"
|
||||
)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message=(
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://docs.goauthentik.io/troubleshooting/missing_permission"
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
else:
|
||||
app_label, perm = model_or_perm.split(".")
|
||||
permission = Permission.objects.filter(
|
||||
codename=perm,
|
||||
content_type__app_label=app_label,
|
||||
)
|
||||
if not permission.exists():
|
||||
LOGGER.warning("permission doesn't exist", perm=model_or_perm)
|
||||
continue
|
||||
user.assign_perms_to_managed_role(permission.first())
|
||||
user.assign_perms_to_managed_role(code_name, model_or_perm)
|
||||
elif isinstance(model_or_perm, tuple):
|
||||
perm, obj = model_or_perm
|
||||
user.assign_perms_to_managed_role(perm, obj)
|
||||
else:
|
||||
user.assign_perms_to_managed_role(model_or_perm)
|
||||
except (Permission.DoesNotExist, AttributeError) as exc:
|
||||
LOGGER.warning(
|
||||
"permission doesn't exist",
|
||||
code_name=code_name,
|
||||
user=user,
|
||||
model=model_or_perm,
|
||||
)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message=(
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://docs.goauthentik.io/troubleshooting/missing_permission"
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
LOGGER.debug(
|
||||
"Updated service account's permissions",
|
||||
obj_perms=user.get_all_obj_perms_on_managed_role(),
|
||||
@@ -431,7 +425,7 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
Token.objects.filter(identifier=self.token_identifier).delete()
|
||||
return self.token
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
"""Get an iterator of all objects the user needs read access to"""
|
||||
objects: list[models.Model | str] = [
|
||||
self,
|
||||
@@ -445,7 +439,9 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
if self.managed:
|
||||
for brand in Brand.objects.filter(web_certificate__isnull=False):
|
||||
objects.append(brand)
|
||||
objects.append(brand.web_certificate)
|
||||
objects.append(("view_certificatekeypair", brand.web_certificate))
|
||||
objects.append(("view_certificatekeypair_certificate", brand.web_certificate))
|
||||
objects.append(("view_certificatekeypair_key", brand.web_certificate))
|
||||
return objects
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -51,10 +51,12 @@ class OutpostTests(TestCase):
|
||||
permissions = outpost.user.get_all_obj_perms_on_managed_role().order_by(
|
||||
"content_type__model"
|
||||
)
|
||||
self.assertEqual(len(permissions), 3)
|
||||
self.assertEqual(len(permissions), 5)
|
||||
self.assertEqual(permissions[0].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[1].object_pk, str(outpost.pk))
|
||||
self.assertEqual(permissions[2].object_pk, str(provider.pk))
|
||||
self.assertEqual(permissions[1].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[2].object_pk, str(keypair.pk))
|
||||
self.assertEqual(permissions[3].object_pk, str(outpost.pk))
|
||||
self.assertEqual(permissions[4].object_pk, str(provider.pk))
|
||||
|
||||
# Remove provider from outpost, user should only have access to outpost
|
||||
outpost.providers.remove(provider)
|
||||
|
||||
@@ -93,11 +93,13 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
def __str__(self):
|
||||
return f"LDAP Provider {self.name}"
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
required_models = [self, "authentik_core.view_user", "authentik_core.view_group"]
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
required = [self, "authentik_core.view_user", "authentik_core.view_group"]
|
||||
if self.certificate is not None:
|
||||
required_models.append(self.certificate)
|
||||
return required_models
|
||||
required.append(("view_certificatekeypair", self.certificate))
|
||||
required.append(("view_certificatekeypair_certificate", self.certificate))
|
||||
required.append(("view_certificatekeypair_key", self.certificate))
|
||||
return required
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("LDAP Provider")
|
||||
|
||||
@@ -152,7 +152,7 @@ class IDToken:
|
||||
final = self.to_dict()
|
||||
final["azp"] = provider.client_id
|
||||
final["uid"] = generate_id()
|
||||
final["scope"] = " ".join(token.scope)
|
||||
final.setdefault("scope", " ".join(token.scope))
|
||||
return provider.encode(final)
|
||||
|
||||
def to_jwt(self, provider: "OAuth2Provider") -> str:
|
||||
|
||||
@@ -436,3 +436,57 @@ class TestToken(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.validate_jwt(access, provider)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_scope_claim_override_via_property_mapping(self):
|
||||
"""Test that property mappings can override the scope claim in access tokens.
|
||||
|
||||
See: https://github.com/goauthentik/authentik/issues/19224
|
||||
"""
|
||||
# Create a custom scope mapping that returns a custom scope claim
|
||||
custom_scope_mapping = ScopeMapping.objects.create(
|
||||
name="custom-scope-override",
|
||||
scope_name="custom",
|
||||
expression='return {"scope": "custom-scope-value additional-scope"}',
|
||||
)
|
||||
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
signing_key=self.keypair,
|
||||
include_claims_in_id_token=True,
|
||||
)
|
||||
provider.property_mappings.add(custom_scope_mapping)
|
||||
|
||||
# Needs to be assigned to an application for iss to be set
|
||||
self.app.provider = provider
|
||||
self.app.save()
|
||||
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
code = AuthorizationCode.objects.create(
|
||||
code="foobar",
|
||||
provider=provider,
|
||||
user=user,
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid custom", # Request the custom scope
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token"),
|
||||
data={
|
||||
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
|
||||
"code": code.code,
|
||||
"redirect_uri": "http://local.invalid",
|
||||
},
|
||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
|
||||
jwt_data = self.validate_jwt(access, provider)
|
||||
|
||||
# The scope should be the custom value from the property mapping,
|
||||
# not the default "openid custom"
|
||||
self.assertEqual(jwt_data["scope"], "custom-scope-value additional-scope")
|
||||
|
||||
@@ -179,11 +179,13 @@ class ProxyProvider(OutpostModel, OAuth2Provider):
|
||||
def __str__(self):
|
||||
return f"Proxy Provider {self.name}"
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
required_models = [self]
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
required = [self]
|
||||
if self.certificate is not None:
|
||||
required_models.append(self.certificate)
|
||||
return required_models
|
||||
required.append(("view_certificatekeypair", self.certificate))
|
||||
required.append(("view_certificatekeypair_certificate", self.certificate))
|
||||
required.append(("view_certificatekeypair_key", self.certificate))
|
||||
return required
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Proxy Provider")
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""proxy provider tests"""
|
||||
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.oauth2.models import ClientTypes
|
||||
from authentik.providers.proxy.models import ProxyMode, ProxyProvider
|
||||
|
||||
@@ -127,3 +131,55 @@ class ProxyProviderTests(APITestCase):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
provider: ProxyProvider = ProxyProvider.objects.get(name=name)
|
||||
self.assertEqual(provider.client_type, ClientTypes.CONFIDENTIAL)
|
||||
|
||||
def test_sa_fetch(self):
|
||||
"""Test fetching the outpost config as the service account"""
|
||||
outpost = Outpost.objects.create(name=generate_id(), type=OutpostType.PROXY)
|
||||
provider = ProxyProvider.objects.create(name=generate_id())
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
|
||||
outpost.providers.add(provider)
|
||||
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:proxyprovideroutpost-list"),
|
||||
HTTP_AUTHORIZATION=f"Bearer {outpost.token.key}",
|
||||
)
|
||||
body = loads(res.content)
|
||||
self.assertEqual(body["pagination"]["count"], 1)
|
||||
|
||||
def test_sa_perms_cert(self):
|
||||
"""Test permissions to access a configured certificate"""
|
||||
cert = create_test_cert()
|
||||
outpost = Outpost.objects.create(name=generate_id(), type=OutpostType.PROXY)
|
||||
provider = ProxyProvider.objects.create(name=generate_id(), certificate=cert)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
|
||||
outpost.providers.add(provider)
|
||||
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:proxyprovideroutpost-list"),
|
||||
HTTP_AUTHORIZATION=f"Bearer {outpost.token.key}",
|
||||
)
|
||||
body = loads(res.content)
|
||||
self.assertEqual(body["pagination"]["count"], 1)
|
||||
cert_id = body["results"][0]["certificate"]
|
||||
self.assertEqual(cert_id, str(cert.pk))
|
||||
|
||||
res = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-certificate",
|
||||
kwargs={
|
||||
"pk": cert_id,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {outpost.token.key}",
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
# res = self.client.get(
|
||||
# reverse(
|
||||
# "authentik_api:certificatekeypair-view-private-key",
|
||||
# kwargs={
|
||||
# "pk": cert_id,
|
||||
# },
|
||||
# ),
|
||||
# HTTP_AUTHORIZATION=f"Bearer {outpost.token.key}",
|
||||
# )
|
||||
# self.assertEqual(res.status_code, 200)
|
||||
|
||||
@@ -64,10 +64,12 @@ class RadiusProvider(OutpostModel, Provider):
|
||||
|
||||
return RadiusProviderSerializer
|
||||
|
||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||
def get_required_objects(self) -> Iterable[models.Model | str | tuple[str, models.Model]]:
|
||||
required = [self, "authentik_stages_mtls.pass_outpost_certificate"]
|
||||
if self.certificate is not None:
|
||||
required.append(self.certificate)
|
||||
required.append(("view_certificatekeypair", self.certificate))
|
||||
required.append(("view_certificatekeypair_certificate", self.certificate))
|
||||
required.append(("view_certificatekeypair_key", self.certificate))
|
||||
return required
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.http import Http404
|
||||
from django_filters.filters import AllValuesMultipleFilter, BooleanFilter
|
||||
from django_filters.filters import AllValuesMultipleFilter, BooleanFilter, CharFilter, NumberFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_field
|
||||
@@ -22,7 +22,7 @@ from authentik.blueprints.api import ManagedSerializer
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.models import Role, get_permission_choices
|
||||
|
||||
@@ -65,15 +65,63 @@ class RoleSerializer(ManagedSerializer, ModelSerializer):
|
||||
|
||||
|
||||
class RoleFilterSet(FilterSet):
|
||||
"""Filter for PropertyMapping"""
|
||||
"""Filter for Role"""
|
||||
|
||||
managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))
|
||||
|
||||
managed__isnull = BooleanFilter(field_name="managed", lookup_expr="isnull")
|
||||
|
||||
inherited = BooleanFilter(
|
||||
method="filter_inherited",
|
||||
label="Include inherited roles (requires users or ak_groups filter)",
|
||||
)
|
||||
|
||||
users = extend_schema_field(OpenApiTypes.INT)(
|
||||
NumberFilter(
|
||||
method="filter_users",
|
||||
label="Filter by user (use with inherited=true for all roles)",
|
||||
)
|
||||
)
|
||||
|
||||
ak_groups = extend_schema_field(OpenApiTypes.UUID)(
|
||||
CharFilter(
|
||||
method="filter_ak_groups",
|
||||
label="Filter by group (use with inherited=true for all roles)",
|
||||
)
|
||||
)
|
||||
|
||||
def filter_inherited(self, queryset, name, value):
|
||||
"""This filter is handled by filter_users and filter_ak_groups"""
|
||||
return queryset
|
||||
|
||||
def filter_users(self, queryset, name, value):
|
||||
"""Filter roles by user, optionally including inherited roles"""
|
||||
user = User.objects.filter(pk=value).first()
|
||||
if not user:
|
||||
return queryset.none()
|
||||
|
||||
include_inherited = self.data.get("inherited", "").lower() == "true"
|
||||
if include_inherited:
|
||||
return user.all_roles()
|
||||
return queryset.filter(users=user)
|
||||
|
||||
def filter_ak_groups(self, queryset, name, value):
|
||||
"""Filter roles by group, optionally including inherited roles"""
|
||||
group = Group.objects.filter(pk=value).first()
|
||||
if not group:
|
||||
return queryset.none()
|
||||
|
||||
include_inherited = self.data.get("inherited", "").lower() == "true"
|
||||
if include_inherited:
|
||||
return group.all_roles()
|
||||
return queryset.filter(ak_groups=group)
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = ["name", "users", "managed"]
|
||||
fields = [
|
||||
"name",
|
||||
"managed",
|
||||
]
|
||||
|
||||
|
||||
class RoleViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
@@ -50,7 +50,7 @@ def get_user(scope):
|
||||
"Cannot find session in scope. You should wrap your consumer in SessionMiddleware."
|
||||
)
|
||||
user = None
|
||||
if (authenticated_session := scope["session"].get("authenticated_session", None)) is not None:
|
||||
if (authenticated_session := scope["session"].get("authenticatedsession", None)) is not None:
|
||||
user = authenticated_session.user
|
||||
return user or AnonymousUser()
|
||||
|
||||
|
||||
@@ -190,6 +190,7 @@ SPECTACULAR_SETTINGS = {
|
||||
"PKCEMethodEnum": "authentik.sources.oauth.models.PKCEMethod",
|
||||
"DeviceFactsOSFamily": "authentik.endpoints.facts.OSFamily",
|
||||
"StageModeEnum": "authentik.endpoints.models.StageMode",
|
||||
"LicenseSummaryStatusEnum": "authentik.enterprise.models.LicenseUsageStatus",
|
||||
},
|
||||
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
||||
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
|
||||
|
||||
@@ -96,6 +96,9 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
|
||||
def add_arguments(cls, parser: ArgumentParser):
|
||||
"""Add more pytest-specific arguments"""
|
||||
DiscoverRunner.add_arguments(parser)
|
||||
default_seed = None
|
||||
if seed := os.getenv("CI_TEST_SEED"):
|
||||
default_seed = int(seed)
|
||||
parser.add_argument(
|
||||
"--randomly-seed",
|
||||
type=int,
|
||||
@@ -103,6 +106,7 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
|
||||
"to reuse the seed from the previous run."
|
||||
"Default behaviour: use random.Random().getrandbits(32), so the seed is"
|
||||
"different on each run.",
|
||||
default=default_seed,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-capture",
|
||||
|
||||
0
authentik/root/tests/__init__.py
Normal file
0
authentik/root/tests/__init__.py
Normal file
115
authentik/root/tests/test_ws_client.py
Normal file
115
authentik/root/tests/test_ws_client.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from channels.routing import URLRouter
|
||||
from channels.testing import WebsocketCommunicator
|
||||
from django.http import HttpRequest
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import (
|
||||
Event,
|
||||
EventAction,
|
||||
Notification,
|
||||
NotificationTransport,
|
||||
TransportMode,
|
||||
)
|
||||
from authentik.flows.apps import RefreshOtherFlowsAfterAuthentication
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.root import websocket
|
||||
from authentik.stages.password import BACKEND_INBUILT
|
||||
from authentik.stages.user_login.stage import COOKIE_NAME_KNOWN_DEVICE
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
class TestClientWS(TransactionTestCase):
|
||||
|
||||
def setUp(self):
|
||||
tenant = get_current_tenant()
|
||||
tenant.flags[RefreshOtherFlowsAfterAuthentication().key] = True
|
||||
tenant.save()
|
||||
self.user = create_test_user()
|
||||
|
||||
async def _alogin_cookie(self, user, **kwargs):
|
||||
"""Similar to `client.aforce_login` but allow setting of cookies"""
|
||||
from django.contrib.auth import alogin
|
||||
|
||||
# Create a fake request to store login details.
|
||||
request = HttpRequest()
|
||||
session = await self.client.asession()
|
||||
request.session = session
|
||||
request.COOKIES.update(kwargs)
|
||||
|
||||
await alogin(request, user, BACKEND_INBUILT)
|
||||
# Save the session values.
|
||||
await request.session.asave()
|
||||
self.client._set_login_cookies(request)
|
||||
|
||||
async def test_auth_blank(self):
|
||||
dev_id = generate_id()
|
||||
communicator = WebsocketCommunicator(
|
||||
URLRouter(websocket.websocket_urlpatterns),
|
||||
"/ws/client/",
|
||||
headers=[(b"cookie", f"{COOKIE_NAME_KNOWN_DEVICE}={dev_id}".encode())],
|
||||
)
|
||||
connected, _ = await communicator.connect()
|
||||
self.assertTrue(connected)
|
||||
|
||||
await self._alogin_cookie(self.user, **{COOKIE_NAME_KNOWN_DEVICE: dev_id})
|
||||
|
||||
await communicator.receive_nothing()
|
||||
await communicator.receive_json_from()
|
||||
await communicator.disconnect()
|
||||
|
||||
async def test_tab_refresh(self):
|
||||
dev_id = generate_id()
|
||||
communicator = WebsocketCommunicator(
|
||||
URLRouter(websocket.websocket_urlpatterns),
|
||||
"/ws/client/",
|
||||
headers=[(b"cookie", f"{COOKIE_NAME_KNOWN_DEVICE}={dev_id}".encode())],
|
||||
)
|
||||
connected, _ = await communicator.connect()
|
||||
self.assertTrue(connected)
|
||||
|
||||
with patch("authentik.flows.apps.RefreshOtherFlowsAfterAuthentication.get") as flag:
|
||||
flag.return_value = True
|
||||
await self._alogin_cookie(self.user, **{COOKIE_NAME_KNOWN_DEVICE: dev_id})
|
||||
|
||||
evt = await communicator.receive_json_from()
|
||||
self.assertEqual(
|
||||
evt, {"message_type": "session.authenticated", "type": "event.session.authenticated"}
|
||||
)
|
||||
|
||||
await communicator.disconnect()
|
||||
|
||||
async def test_notification(self):
|
||||
communicator = WebsocketCommunicator(
|
||||
URLRouter(websocket.websocket_urlpatterns), "/ws/client/"
|
||||
)
|
||||
communicator.scope["user"] = self.user
|
||||
connected, _ = await communicator.connect()
|
||||
self.assertTrue(connected)
|
||||
|
||||
transport = await NotificationTransport.objects.acreate(
|
||||
name=generate_id(), mode=TransportMode.LOCAL
|
||||
)
|
||||
event = await sync_to_async(Event.new)(EventAction.LOGIN)
|
||||
event.set_user(self.user)
|
||||
await event.asave()
|
||||
notification = Notification(
|
||||
user=self.user,
|
||||
body="foo",
|
||||
event=event,
|
||||
hyperlink="goauthentik.io",
|
||||
hyperlink_label="a link",
|
||||
)
|
||||
await sync_to_async(transport.send_local)(notification)
|
||||
|
||||
evt = await communicator.receive_json_from(timeout=5)
|
||||
self.assertEqual(evt["message_type"], "notification.new")
|
||||
self.assertEqual(evt["id"], str(notification.pk))
|
||||
self.assertEqual(evt["data"]["pk"], str(notification.pk))
|
||||
self.assertEqual(evt["data"]["body"], "foo")
|
||||
self.assertEqual(evt["data"]["event"]["pk"], str(event.pk))
|
||||
|
||||
await communicator.disconnect()
|
||||
@@ -7,6 +7,7 @@ from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from django.core.cache import cache
|
||||
from django.db import connection
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.root.ws.storage import CACHE_PREFIX
|
||||
|
||||
|
||||
@@ -16,24 +17,34 @@ def build_session_group(session_key: str):
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def build_device_group(session_key: str):
|
||||
def build_device_group(device_id: str):
|
||||
return sha256(
|
||||
f"{connection.schema_name}/group_client_device_{str(session_key)}".encode()
|
||||
f"{connection.schema_name}/group_client_device_{str(device_id)}".encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def build_user_group(user: User):
|
||||
return sha256(f"{connection.schema_name}/group_client_user_{user.uuid}".encode()).hexdigest()
|
||||
|
||||
|
||||
class MessageConsumer(JsonWebsocketConsumer):
|
||||
"""Consumer which sends django.contrib.messages Messages over WS.
|
||||
channel_name is saved into cache with user_id, and when a add_message is called"""
|
||||
|
||||
session_key: str
|
||||
device_cookie: str | None = None
|
||||
user: User | None = None
|
||||
|
||||
def connect(self):
|
||||
self.accept()
|
||||
self.session_key = self.scope["session"].session_key
|
||||
if self.session_key:
|
||||
cache.set(f"{CACHE_PREFIX}{self.session_key}_messages_{self.channel_name}", True, None)
|
||||
if user := self.scope.get("user"):
|
||||
if user.is_authenticated:
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
build_user_group(user), self.channel_name
|
||||
)
|
||||
if device_cookie := self.scope["cookies"].get("authentik_device", None):
|
||||
self.device_cookie = device_cookie
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
@@ -47,6 +58,10 @@ class MessageConsumer(JsonWebsocketConsumer):
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
build_device_group(self.device_cookie), self.channel_name
|
||||
)
|
||||
if self.user:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
build_user_group(self.user), self.channel_name
|
||||
)
|
||||
|
||||
def event_message(self, event: dict):
|
||||
"""Event handler which is called by Messages Storage backend"""
|
||||
@@ -54,4 +69,8 @@ class MessageConsumer(JsonWebsocketConsumer):
|
||||
|
||||
def event_session_authenticated(self, event: dict):
|
||||
"""Event handler post user authentication"""
|
||||
self.send_json({"message_type": "session.authenticated"})
|
||||
self.send_json({"message_type": "session.authenticated", **event})
|
||||
|
||||
def event_notification(self, event: dict):
|
||||
"""Event handler for new notifications"""
|
||||
self.send_json({"message_type": "notification.new", **event})
|
||||
|
||||
@@ -6,7 +6,7 @@ from django.http.request import QueryDict
|
||||
from django.template.exceptions import TemplateSyntaxError
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import BooleanField, CharField, IntegerField
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.challenge import (
|
||||
@@ -47,7 +47,7 @@ class AuthenticatorEmailChallengeResponse(ChallengeResponse):
|
||||
|
||||
device: EmailDevice
|
||||
|
||||
code = IntegerField(required=False)
|
||||
code = CharField(required=False)
|
||||
email = CharField(required=False)
|
||||
|
||||
component = CharField(default="ak-stage-authenticator-email")
|
||||
|
||||
@@ -5,7 +5,7 @@ from django.http import HttpRequest, HttpResponse
|
||||
from django.http.request import QueryDict
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import BooleanField, CharField, IntegerField
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
|
||||
from authentik.flows.challenge import (
|
||||
Challenge,
|
||||
@@ -38,7 +38,7 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
|
||||
|
||||
device: SMSDevice
|
||||
|
||||
code = IntegerField(required=False)
|
||||
code = CharField(required=False)
|
||||
phone_number = CharField(required=False)
|
||||
|
||||
component = CharField(default="ak-stage-authenticator-sms")
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Generated by Django 5.2.9 on 2026-01-06 23:52
|
||||
|
||||
import django.core.validators
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_stages_authenticator_static",
|
||||
"0011_alter_authenticatorstaticstage_friendly_name",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="authenticatorstaticstage",
|
||||
name="token_length",
|
||||
field=models.PositiveIntegerField(
|
||||
default=12, validators=[django.core.validators.MaxValueValidator(100)]
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="statictoken",
|
||||
name="token",
|
||||
field=models.CharField(db_index=True, max_length=100),
|
||||
),
|
||||
]
|
||||
@@ -4,6 +4,7 @@ from base64 import b32encode
|
||||
from os import urandom
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.validators import MaxValueValidator
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views import View
|
||||
@@ -19,7 +20,7 @@ class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
"""Setup static token based authentication for the user."""
|
||||
|
||||
token_count = models.PositiveIntegerField(default=6)
|
||||
token_length = models.PositiveIntegerField(default=12)
|
||||
token_length = models.PositiveIntegerField(default=12, validators=[MaxValueValidator(100)])
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
@@ -109,11 +110,11 @@ class StaticToken(models.Model):
|
||||
|
||||
.. attribute:: token
|
||||
|
||||
*CharField*: A random string up to 16 characters.
|
||||
*CharField*: A random string up to 100 characters.
|
||||
"""
|
||||
|
||||
device = models.ForeignKey(StaticDevice, related_name="token_set", on_delete=models.CASCADE)
|
||||
token = models.CharField(max_length=16, db_index=True)
|
||||
token = models.CharField(max_length=100, db_index=True)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Static Token")
|
||||
|
||||
@@ -5,7 +5,7 @@ from urllib.parse import quote
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http.request import QueryDict
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.serializers import ValidationError
|
||||
|
||||
from authentik.flows.challenge import (
|
||||
@@ -32,10 +32,10 @@ class AuthenticatorTOTPChallengeResponse(ChallengeResponse):
|
||||
|
||||
device: TOTPDevice
|
||||
|
||||
code = IntegerField()
|
||||
code = CharField()
|
||||
component = CharField(default="ak-stage-authenticator-totp")
|
||||
|
||||
def validate_code(self, code: int) -> int:
|
||||
def validate_code(self, code: str) -> str:
|
||||
"""Validate totp code"""
|
||||
if not self.device:
|
||||
raise ValidationError(_("Code does not match"))
|
||||
|
||||
@@ -21,6 +21,7 @@ from authentik.flows.models import FlowDesignation, NotConfiguredAction, Stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.reputation.signals import update_score
|
||||
from authentik.stages.authenticator import devices_for_user
|
||||
from authentik.stages.authenticator.models import Device
|
||||
from authentik.stages.authenticator_email.models import EmailDevice
|
||||
@@ -418,6 +419,10 @@ class AuthenticatorValidateStageView(ChallengeStageView):
|
||||
)
|
||||
return response
|
||||
|
||||
def challenge_invalid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse:
|
||||
update_score(self.request, self.get_pending_user().username, -1)
|
||||
return super().challenge_invalid(response)
|
||||
|
||||
def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse:
|
||||
# All validation is done by the serializer
|
||||
user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Identification stage logic"""
|
||||
|
||||
from dataclasses import asdict
|
||||
from random import SystemRandom
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db.models import Q
|
||||
from django.http import HttpResponse
|
||||
@@ -18,6 +17,9 @@ from sentry_sdk import start_span
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Application, Source, User
|
||||
from authentik.endpoints.models import Device
|
||||
from authentik.enterprise.endpoints.connectors.agent.views.auth_interactive import (
|
||||
PLAN_CONTEXT_DEVICE_AUTH_TOKEN,
|
||||
)
|
||||
from authentik.events.middleware import audit_ignore
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.flows.challenge import (
|
||||
@@ -161,8 +163,8 @@ class IdentificationChallengeResponse(ChallengeResponse):
|
||||
op="authentik.stages.identification.validate_invalid_wait",
|
||||
name="Sleep random time on invalid user identifier",
|
||||
):
|
||||
# Sleep a random time (between 90 and 210ms) to "prevent" user enumeration attacks
|
||||
sleep(0.030 * SystemRandom().randint(3, 7))
|
||||
# hash a random password on invalid identifier, same as with a valid identifier
|
||||
make_password(make_password(None))
|
||||
# Log in a similar format to Event.new(), but we don't want to create an event here
|
||||
# as this stage is mostly used by unauthenticated users with very high rate limits
|
||||
self.stage.logger.info(
|
||||
@@ -316,7 +318,10 @@ class IdentificationStageView(ChallengeStageView):
|
||||
challenge.initial_data["application_pre"] = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_APPLICATION, Application()
|
||||
).name
|
||||
if PLAN_CONTEXT_DEVICE in self.executor.plan.context:
|
||||
if (
|
||||
PLAN_CONTEXT_DEVICE in self.executor.plan.context
|
||||
and PLAN_CONTEXT_DEVICE_AUTH_TOKEN in self.executor.plan.context
|
||||
):
|
||||
challenge.initial_data["application_pre"] = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_DEVICE, Device()
|
||||
).name
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
from django.contrib.auth import _clean_credentials
|
||||
from django.contrib.auth.backends import BaseBackend
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db.models import Sum
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -25,13 +26,14 @@ from authentik.flows.models import Flow, Stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.policies.reputation.models import Reputation
|
||||
from authentik.stages.password.models import PasswordStage
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND = "user_backend"
|
||||
PLAN_CONTEXT_METHOD = "auth_method"
|
||||
PLAN_CONTEXT_METHOD_ARGS = "auth_method_args"
|
||||
SESSION_KEY_INVALID_TRIES = "authentik/stages/password/user_invalid_tries"
|
||||
PLAN_CONTEXT_INITIAL_SCORE = "goauthentik.io/stages/password/initial_score"
|
||||
|
||||
|
||||
def authenticate(
|
||||
@@ -148,19 +150,27 @@ class PasswordStageView(ChallengeStageView):
|
||||
kwargs={"flow_slug": recovery_flow.slug},
|
||||
)
|
||||
challenge.initial_data["recovery_url"] = self.request.build_absolute_uri(recover_url)
|
||||
if PLAN_CONTEXT_INITIAL_SCORE not in self.executor.plan.context:
|
||||
self.executor.plan.context[PLAN_CONTEXT_INITIAL_SCORE] = self.get_reputation_score()
|
||||
return challenge
|
||||
|
||||
def get_reputation_score(self) -> int:
|
||||
return (
|
||||
Reputation.objects.filter(identifier=self.get_pending_user().username).aggregate(
|
||||
total_score=Sum("score")
|
||||
)["total_score"]
|
||||
or 0
|
||||
)
|
||||
|
||||
def challenge_invalid(self, response: PasswordChallengeResponse) -> HttpResponse:
|
||||
if SESSION_KEY_INVALID_TRIES not in self.request.session:
|
||||
self.request.session[SESSION_KEY_INVALID_TRIES] = 0
|
||||
self.request.session[SESSION_KEY_INVALID_TRIES] += 1
|
||||
current_stage: PasswordStage = self.executor.current_stage
|
||||
if (
|
||||
self.request.session[SESSION_KEY_INVALID_TRIES]
|
||||
>= current_stage.failed_attempts_before_cancel
|
||||
):
|
||||
initial_score = self.executor.plan.context.get(PLAN_CONTEXT_INITIAL_SCORE)
|
||||
if initial_score is None:
|
||||
initial_score = self.get_reputation_score()
|
||||
self.executor.plan.context[PLAN_CONTEXT_INITIAL_SCORE] = initial_score
|
||||
new_score = self.get_reputation_score()
|
||||
if (initial_score - new_score) >= current_stage.failed_attempts_before_cancel:
|
||||
self.logger.debug("User has exceeded maximum tries")
|
||||
del self.request.session[SESSION_KEY_INVALID_TRIES]
|
||||
return self.executor.stage_invalid(_("Invalid password"))
|
||||
return super().challenge_invalid(response)
|
||||
|
||||
|
||||
@@ -135,6 +135,13 @@ class TestPasswordStage(FlowTestCase):
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
|
||||
res = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:flow-executor",
|
||||
kwargs={"flow_slug": self.flow.slug},
|
||||
),
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
for _ in range(self.stage.failed_attempts_before_cancel - 1):
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
|
||||
@@ -37,7 +37,11 @@ class PromptStageSerializer(StageSerializer):
|
||||
class PromptStageViewSet(UsedByMixin, ModelViewSet):
|
||||
"""PromptStage Viewset"""
|
||||
|
||||
queryset = PromptStage.objects.all()
|
||||
queryset = PromptStage.objects.prefetch_related(
|
||||
"flow_set",
|
||||
"fields",
|
||||
"validation_policies",
|
||||
).all()
|
||||
serializer_class = PromptStageSerializer
|
||||
filterset_fields = "__all__"
|
||||
ordering = ["name"]
|
||||
@@ -73,7 +77,12 @@ class PromptSerializer(ModelSerializer):
|
||||
class PromptViewSet(UsedByMixin, ModelViewSet):
|
||||
"""Prompt Viewset"""
|
||||
|
||||
queryset = Prompt.objects.all().prefetch_related("promptstage_set")
|
||||
queryset = Prompt.objects.all().prefetch_related(
|
||||
"promptstage_set",
|
||||
"promptstage_set__flow_set",
|
||||
"promptstage_set__fields",
|
||||
"promptstage_set__validation_policies",
|
||||
)
|
||||
serializer_class = PromptSerializer
|
||||
ordering = ["field_key"]
|
||||
filterset_fields = ["field_key", "name", "label", "type", "placeholder"]
|
||||
|
||||
@@ -245,7 +245,10 @@ class WorkerStatusMiddleware(Middleware):
|
||||
WorkerStatusMiddleware.keep(status)
|
||||
except DB_ERRORS: # pragma: no cover
|
||||
sleep(10)
|
||||
pass
|
||||
try:
|
||||
connections.close_all()
|
||||
except DB_ERRORS:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def keep(status: WorkerStatus):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from random import choice
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS, connections
|
||||
|
||||
|
||||
class FailoverRouter:
|
||||
@@ -10,16 +11,22 @@ class FailoverRouter:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.database_aliases = set(settings.DATABASES.keys())
|
||||
self.read_replica_aliases = list(self.database_aliases - {"default"})
|
||||
self.read_replica_aliases = list(self.database_aliases - {DEFAULT_DB_ALIAS})
|
||||
self.replica_enabled = len(self.read_replica_aliases) > 0
|
||||
|
||||
def db_for_read(self, model, **hints):
|
||||
if not self.replica_enabled:
|
||||
return "default"
|
||||
return DEFAULT_DB_ALIAS
|
||||
# Stay on primary for the entire transaction to maintain consistency.
|
||||
# Reading from a replica mid-transaction would give a different snapshot,
|
||||
# breaking transactional semantics (not just read-your-writes, but the
|
||||
# entire consistent point-in-time view that a transaction provides).
|
||||
if connections[DEFAULT_DB_ALIAS].in_atomic_block:
|
||||
return DEFAULT_DB_ALIAS
|
||||
return choice(self.read_replica_aliases) # nosec
|
||||
|
||||
def db_for_write(self, model, **hints):
|
||||
return "default"
|
||||
return DEFAULT_DB_ALIAS
|
||||
|
||||
def allow_relation(self, obj1, obj2, **hints):
|
||||
"""Relations between objects are allowed if both objects are
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://goauthentik.io/blueprints/schema.json",
|
||||
"type": "object",
|
||||
"title": "authentik 2025.12.0-rc1 Blueprint schema",
|
||||
"title": "authentik 2025.12.1 Blueprint schema",
|
||||
"required": [
|
||||
"version",
|
||||
"entries"
|
||||
@@ -6276,6 +6276,11 @@
|
||||
],
|
||||
"format": "date-time",
|
||||
"title": "Expires"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Key"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
@@ -13707,7 +13712,7 @@
|
||||
"token_length": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 2147483647,
|
||||
"maximum": 100,
|
||||
"title": "Token length"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -31,13 +31,13 @@ services:
|
||||
AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
|
||||
AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
|
||||
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?secret key required}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.12.0-rc1}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.12.1}
|
||||
ports:
|
||||
- ${COMPOSE_PORT_HTTP:-9000}:9000
|
||||
- ${COMPOSE_PORT_HTTPS:-9443}:9443
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./media:/data/media
|
||||
- ./data:/data
|
||||
- ./custom-templates:/templates
|
||||
worker:
|
||||
command: worker
|
||||
@@ -52,12 +52,12 @@ services:
|
||||
AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
|
||||
AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
|
||||
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?secret key required}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.12.0-rc1}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.12.1}
|
||||
restart: unless-stopped
|
||||
user: root
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./media:/data/media
|
||||
- ./data:/data
|
||||
- ./certs:/certs
|
||||
- ./custom-templates:/templates
|
||||
volumes:
|
||||
|
||||
6
go.mod
6
go.mod
@@ -1,8 +1,6 @@
|
||||
module goauthentik.io
|
||||
|
||||
go 1.24.3
|
||||
|
||||
toolchain go1.24.6
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
beryju.io/ldap v0.1.0
|
||||
@@ -32,7 +30,7 @@ require (
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/wwt/guac v1.3.2
|
||||
goauthentik.io/api/v3 v3.2025120.26
|
||||
goauthentik.io/api/v3 v3.2026020.7
|
||||
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/sync v0.19.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -214,8 +214,8 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
goauthentik.io/api/v3 v3.2025120.26 h1:2lTMtjCWtdOeQe7kwjpGUx39qUEpcxcxTirIqMvn0Os=
|
||||
goauthentik.io/api/v3 v3.2025120.26/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
|
||||
goauthentik.io/api/v3 v3.2026020.7 h1:/Op0pV6liiv+dJT3BhZdypIrqMimlABqAap/sMjphyo=
|
||||
goauthentik.io/api/v3 v3.2026020.7/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
|
||||
@@ -204,6 +204,7 @@ func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
|
||||
// Map of environment variable suffix to config field pointer
|
||||
envVars := map[string]*string{
|
||||
"HOST": &refreshed.Host,
|
||||
"PORT": &refreshed.Port,
|
||||
"USER": &refreshed.User,
|
||||
"PASSWORD": &refreshed.Password,
|
||||
"NAME": &refreshed.Name,
|
||||
|
||||
@@ -27,7 +27,7 @@ type Config struct {
|
||||
|
||||
type PostgreSQLConfig struct {
|
||||
Host string `yaml:"host" env:"HOST, overwrite"`
|
||||
Port int `yaml:"port" env:"PORT, overwrite"`
|
||||
Port string `yaml:"port" env:"PORT, overwrite"`
|
||||
User string `yaml:"user" env:"USER, overwrite"`
|
||||
Password string `yaml:"password" env:"PASSWORD, overwrite"`
|
||||
Name string `yaml:"name" env:"NAME, overwrite"`
|
||||
|
||||
@@ -1 +1 @@
|
||||
2025.12.0-rc1
|
||||
2025.12.1
|
||||
@@ -165,7 +165,7 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
for _, u := range g.UsersObj {
|
||||
if flag.UserPk == u.Pk {
|
||||
// TODO: Is there a better way to clone this object?
|
||||
fg := api.NewGroup(g.Pk, g.NumPk, g.Name, []api.RelatedGroup{}, []api.PartialUser{u}, []api.Role{}, []string{}, []api.RelatedGroup{})
|
||||
fg := api.NewGroup(g.Pk, g.NumPk, g.Name, []api.RelatedGroup{}, []api.PartialUser{u}, []api.Role{}, nil, []string{}, []api.RelatedGroup{})
|
||||
fg.SetUsers([]int32{flag.UserPk})
|
||||
fg.SetAttributes(g.Attributes)
|
||||
fg.SetIsSuperuser(*g.IsSuperuser)
|
||||
|
||||
@@ -71,7 +71,15 @@ func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
|
||||
func (a *Application) createState(r *http.Request, w http.ResponseWriter, fwd string) (string, error) {
|
||||
s, err := a.sessions.Get(r, a.SessionName())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get session: %w", err)
|
||||
// Session file may not exist (e.g., after outpost restart or logout)
|
||||
// Delete the stale session cookie and continue with the new empty session
|
||||
a.log.WithError(err).Debug("failed to get session, clearing stale cookie")
|
||||
s.Options.MaxAge = -1
|
||||
if saveErr := s.Save(r, w); saveErr != nil {
|
||||
a.log.WithError(saveErr).Warning("failed to delete stale session cookie")
|
||||
}
|
||||
// Get a fresh session after clearing the stale cookie
|
||||
s, _ = a.sessions.Get(r, a.SessionName())
|
||||
}
|
||||
if s.ID == "" {
|
||||
// Ensure session has an ID
|
||||
|
||||
@@ -154,3 +154,39 @@ func TestStateFromRequestDeletesStaleCookie(t *testing.T) {
|
||||
}
|
||||
assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted")
|
||||
}
|
||||
|
||||
func TestCreateStateWithStaleCookie(t *testing.T) {
|
||||
a := newTestApplication()
|
||||
_ = a.configureProxy()
|
||||
|
||||
// Create a request with a stale session cookie (simulates outpost restart or user change)
|
||||
req, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/outpost.goauthentik.io/start", nil)
|
||||
|
||||
// Add a cookie for a non-existent session
|
||||
nonExistentSessionID := uuid.New().String()
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: a.SessionName(),
|
||||
Value: "encoded_session_data_" + nonExistentSessionID,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call createState which should succeed despite the stale cookie
|
||||
state, err := a.createState(req, rr, "/redirect")
|
||||
|
||||
// Verify createState succeeded
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, state)
|
||||
|
||||
// Verify the response includes a Set-Cookie header to delete the stale cookie
|
||||
cookies := rr.Result().Cookies()
|
||||
var foundDeleteCookie bool
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == a.SessionName() && cookie.MaxAge < 0 {
|
||||
foundDeleteCookie = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted")
|
||||
}
|
||||
|
||||
@@ -4,21 +4,23 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
_ "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
@@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("PostgreSQL database name is required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
return nil, fmt.Errorf("PostgreSQL port must be positive")
|
||||
if cfg.Port == "" {
|
||||
return nil, fmt.Errorf("PostgreSQL port is required")
|
||||
}
|
||||
|
||||
// Start with a default config
|
||||
@@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
|
||||
// Set connection parameters
|
||||
connConfig.Host = cfg.Host
|
||||
connConfig.Port = uint16(cfg.Port)
|
||||
// Parse comma-separated hosts and create fallbacks
|
||||
// cfg.Host can be a comma-separated list like "host1,host2,host3"
|
||||
hosts := strings.Split(cfg.Host, ",")
|
||||
for i, host := range hosts {
|
||||
hosts[i] = strings.TrimSpace(host)
|
||||
}
|
||||
|
||||
// Parse and validate comma-separated ports
|
||||
portStrs := strings.Split(cfg.Port, ",")
|
||||
ports := make([]uint16, len(portStrs))
|
||||
for i, portStr := range portStrs {
|
||||
portStr = strings.TrimSpace(portStr)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port value %q: %w", portStr, err)
|
||||
}
|
||||
if port <= 0 {
|
||||
return nil, fmt.Errorf("PostgreSQL port %d must be positive", port)
|
||||
}
|
||||
if port > 65535 {
|
||||
return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
|
||||
// Get port for primary host
|
||||
primaryHost := hosts[0]
|
||||
primaryPort := ports[0]
|
||||
|
||||
// Set connection parameters for primary host
|
||||
connConfig.Host = primaryHost
|
||||
connConfig.Port = primaryPort
|
||||
connConfig.User = cfg.User
|
||||
connConfig.Password = cfg.Password
|
||||
connConfig.Database = cfg.Name
|
||||
@@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
case "verify-full":
|
||||
// Verify the certificate and hostname
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.ServerName = cfg.Host
|
||||
tlsConfig.ServerName = primaryHost
|
||||
}
|
||||
|
||||
connConfig.TLSConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Create fallback configurations for additional hosts
|
||||
if len(hosts) > 1 {
|
||||
connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1)
|
||||
for i, host := range hosts[1:] {
|
||||
port := getPortForIndex(ports, i+1)
|
||||
fallback := &pgconn.FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
// Copy TLS config to fallback if present
|
||||
if connConfig.TLSConfig != nil {
|
||||
fallbackTLS := connConfig.TLSConfig.Clone()
|
||||
// Update ServerName for verify-full mode
|
||||
if cfg.SSLMode == "verify-full" {
|
||||
fallbackTLS.ServerName = host
|
||||
}
|
||||
fallback.TLSConfig = fallbackTLS
|
||||
}
|
||||
connConfig.Fallbacks = append(connConfig.Fallbacks, fallback)
|
||||
}
|
||||
}
|
||||
|
||||
// Set runtime params
|
||||
if connConfig.RuntimeParams == nil {
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
@@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
|
||||
// Parse and apply connection options if specified
|
||||
if cfg.ConnOptions != "" {
|
||||
// Parse key=value pairs from ConnOptions
|
||||
// Format: "key1=value1 key2=value2"
|
||||
pairs := strings.Split(cfg.ConnOptions, " ")
|
||||
for _, pair := range pairs {
|
||||
if pair == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) == 2 {
|
||||
connConfig.RuntimeParams[kv[0]] = kv[1]
|
||||
}
|
||||
connOpts, err := parseConnOptions(cfg.ConnOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse connection options: %w", err)
|
||||
}
|
||||
|
||||
if err := applyConnOptions(connConfig, connOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply connection options: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
// getPortForIndex returns the port for the given host index.
|
||||
// If there are fewer ports than needed, returns the last port (libpq behavior).
|
||||
func getPortForIndex(ports []uint16, i int) uint16 {
|
||||
if i >= len(ports) {
|
||||
return ports[len(ports)-1]
|
||||
}
|
||||
return ports[i]
|
||||
}
|
||||
|
||||
// parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
|
||||
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
|
||||
func parseConnOptions(encoded string) (map[string]string, error) {
|
||||
// Base64 decode
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &opts); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
|
||||
// Convert all values to strings
|
||||
result := make(map[string]string)
|
||||
for k, v := range opts {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
result[k] = val
|
||||
case float64:
|
||||
// JSON numbers are float64
|
||||
if val == float64(int(val)) {
|
||||
result[k] = strconv.Itoa(int(val))
|
||||
} else {
|
||||
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
|
||||
}
|
||||
case bool:
|
||||
result[k] = strconv.FormatBool(val)
|
||||
default:
|
||||
result[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyConnOptions applies parsed connection options to the pgx.ConnConfig.
|
||||
func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error {
|
||||
for key, value := range opts {
|
||||
// connect_timeout needs special handling as it's a connection-level timeout
|
||||
if key == "connect_timeout" {
|
||||
timeout, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid connect_timeout value: %w", err)
|
||||
}
|
||||
connConfig.ConnectTimeout = time.Duration(timeout) * time.Second
|
||||
continue
|
||||
}
|
||||
// target_session_attrs needs special handling to set ValidateConnect function
|
||||
if key == "target_session_attrs" {
|
||||
switch value {
|
||||
case "read-write":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
case "read-only":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly
|
||||
case "primary":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby
|
||||
case "prefer-standby":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// "any" is the default (no validation needed)
|
||||
connConfig.ValidateConnect = nil
|
||||
default:
|
||||
return fmt.Errorf("unknown target_session_attrs value: %s", value)
|
||||
}
|
||||
// Do not add target_session_attrs to RuntimeParams
|
||||
continue
|
||||
}
|
||||
// All other options go to RuntimeParams
|
||||
connConfig.RuntimeParams[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
@@ -234,8 +370,8 @@ func NewPostgresStore(log *log.Entry) (*PostgresStore, error) {
|
||||
}
|
||||
|
||||
// Determine connection pool settings
|
||||
maxIdleConns := 10
|
||||
maxOpenConns := 100
|
||||
maxIdleConns := 4
|
||||
maxOpenConns := 4
|
||||
var connMaxLifetime time.Duration
|
||||
if cfg.ConnMaxAge > 0 {
|
||||
connMaxLifetime = time.Duration(cfg.ConnMaxAge) * time.Second
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
@@ -13,12 +14,15 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
@@ -33,7 +37,7 @@ import (
|
||||
func SetupTestDB(t *testing.T) (*gorm.DB, *RefreshableConnPool) {
|
||||
cfg := config.Get().PostgreSQL
|
||||
|
||||
t.Logf("PostgreSQL config: Host=%s Port=%d User=%s DBName=%s SSLMode=%s",
|
||||
t.Logf("PostgreSQL config: Host=%s Port=%s User=%s DBName=%s SSLMode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode)
|
||||
t.Logf("Password length: %d", len(cfg.Password))
|
||||
if cfg.Password == "" {
|
||||
@@ -485,7 +489,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
{
|
||||
name: "missing host",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
@@ -496,7 +500,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "missing user",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
@@ -506,7 +510,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "missing database name",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
},
|
||||
expectError: true,
|
||||
@@ -516,23 +520,23 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "invalid port (zero)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 0,
|
||||
Port: "0",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "PostgreSQL port must be positive",
|
||||
errorMsg: "PostgreSQL port 0 must be positive",
|
||||
},
|
||||
{
|
||||
name: "invalid port (negative)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: -1,
|
||||
Port: "-1",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "PostgreSQL port must be positive",
|
||||
errorMsg: "PostgreSQL port -1 must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -560,7 +564,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "basic configuration",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
@@ -576,7 +580,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with simple password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
Name: "testdb",
|
||||
@@ -589,7 +593,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing spaces",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
@@ -602,7 +606,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing single quotes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "pass'word",
|
||||
Name: "testdb",
|
||||
@@ -615,7 +619,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `pass\word`,
|
||||
Name: "testdb",
|
||||
@@ -628,7 +632,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing special characters",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `p@ss w0rd!#$%^&*()`,
|
||||
Name: "testdb",
|
||||
@@ -641,7 +645,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing quotes and backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `my'pass\word"here`,
|
||||
Name: "testdb",
|
||||
@@ -654,7 +658,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with passphrase (multiple spaces)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "the quick brown fox jumps over",
|
||||
Name: "testdb",
|
||||
@@ -667,7 +671,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with sslmode=disable",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "disable",
|
||||
@@ -680,7 +684,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with sslmode=require (no certs)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
@@ -694,7 +698,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with custom schema",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
DefaultSchema: "custom_schema",
|
||||
@@ -707,27 +711,48 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with connection options",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: "connect_timeout=10 application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"])
|
||||
assert.Equal(t, 10*time.Second, cc.ConnectTimeout)
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with target_session_attrs",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set for target_session_attrs")
|
||||
// Verify it's the correct validator function
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full configuration with special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
Port: "5433",
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "require",
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
@@ -765,7 +790,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "verify-full with all certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
@@ -786,7 +811,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "verify-ca with root cert only",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-ca",
|
||||
@@ -803,7 +828,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "require with client cert",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
@@ -820,7 +845,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "full configuration with SSL and special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
Port: "5433",
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
@@ -829,7 +854,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
@@ -881,7 +906,7 @@ func TestBuildDSN_WithSpecialPasswords(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: tt.password,
|
||||
Name: "testdb",
|
||||
@@ -941,6 +966,221 @@ func TestPostgresStore_ConnectionPoolSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseConnOptions tests the base64 JSON parsing of connection options
|
||||
func TestParseConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "simple key-value",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
expected: map[string]string{"target_session_attrs": "read-write"},
|
||||
},
|
||||
{
|
||||
name: "multiple options",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)),
|
||||
expected: map[string]string{"connect_timeout": "10", "application_name": "authentik"},
|
||||
},
|
||||
{
|
||||
name: "numeric value as number",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10}`)),
|
||||
expected: map[string]string{"connect_timeout": "10"},
|
||||
},
|
||||
{
|
||||
name: "boolean value",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"default_transaction_read_only":true}`)),
|
||||
expected: map[string]string{"default_transaction_read_only": "true"},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{}`)),
|
||||
expected: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
input: "not-valid-base64!!!",
|
||||
expectError: true,
|
||||
errorMsg: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`not json`)),
|
||||
expectError: true,
|
||||
errorMsg: "invalid JSON",
|
||||
},
|
||||
{
|
||||
name: "JSON array instead of object",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`["value1", "value2"]`)),
|
||||
expectError: true,
|
||||
errorMsg: "invalid JSON",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseConnOptions(tt.input)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyConnOptions tests that connection options are applied correctly to pgx.ConnConfig
|
||||
func TestApplyConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts map[string]string
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "connect_timeout sets ConnectTimeout",
|
||||
opts: map[string]string{"connect_timeout": "30"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, 30*time.Second, cc.ConnectTimeout)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs sets ValidateConnect",
|
||||
opts: map[string]string{"target_session_attrs": "read-write"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set")
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "application_name goes to RuntimeParams",
|
||||
opts: map[string]string{"application_name": "my-app"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "my-app", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "statement_timeout goes to RuntimeParams",
|
||||
opts: map[string]string{"statement_timeout": "5000"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "5000", cc.RuntimeParams["statement_timeout"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown options go to RuntimeParams",
|
||||
opts: map[string]string{"custom_param": "custom_value"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "custom_value", cc.RuntimeParams["custom_param"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple options",
|
||||
opts: map[string]string{
|
||||
"connect_timeout": "10",
|
||||
"target_session_attrs": "read-write",
|
||||
"application_name": "authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, 10*time.Second, cc.ConnectTimeout)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set")
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid connect_timeout",
|
||||
opts: map[string]string{"connect_timeout": "not-a-number"},
|
||||
expectError: true,
|
||||
errorMsg: "invalid connect_timeout value",
|
||||
},
|
||||
{
|
||||
name: "invalid target_session_attrs",
|
||||
opts: map[string]string{"target_session_attrs": "invalid-mode"},
|
||||
expectError: true,
|
||||
errorMsg: "unknown target_session_attrs value",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a base config
|
||||
connConfig, err := pgx.ParseConfig("")
|
||||
require.NoError(t, err)
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
|
||||
err = applyConnOptions(connConfig, tt.opts)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
tt.validate(t, connConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_Base64JSONConnOptions tests the full integration of base64 JSON connection options
|
||||
func TestBuildConnConfig_Base64JSONConnOptions(t *testing.T) {
|
||||
t.Run("bug report scenario - target_session_attrs", func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "authentik",
|
||||
Name: "authentik",
|
||||
ConnOptions: "eyJ0YXJnZXRfc2Vzc2lvbl9hdHRycyI6InJlYWQtd3JpdGUifQ==",
|
||||
}
|
||||
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set")
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(connConfig.ValidateConnect).Pointer()).Name())
|
||||
})
|
||||
|
||||
t.Run("complex connection options", func(t *testing.T) {
|
||||
// {"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}
|
||||
connOpts := base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}`))
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "authentik",
|
||||
Name: "authentik",
|
||||
ConnOptions: connOpts,
|
||||
}
|
||||
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 10*time.Second, connConfig.ConnectTimeout)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set")
|
||||
assert.Equal(t, "authentik-proxy", connConfig.RuntimeParams["application_name"])
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to create session data JSON
|
||||
func createSessionData(t *testing.T, claims map[string]interface{}) string {
|
||||
sessionData := map[string]interface{}{
|
||||
@@ -1036,3 +1276,495 @@ func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPat
|
||||
|
||||
return rootCertPath, clientCertPath, clientKeyPath, cleanup
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_WithBase64EncodedConnOptions demonstrates that ConnOptions
|
||||
// should be base64-encoded JSON but is currently being parsed as key=value pairs
|
||||
func TestBuildConnConfig_WithBase64EncodedConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
connOptions string
|
||||
expected map[string]string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "base64 encoded JSON with single parameter",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10"}`)),
|
||||
expected: map[string]string{
|
||||
// connect_timeout is handled specially and NOT added to RuntimeParams
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with multiple parameters",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik","statement_timeout":"30000"}`)),
|
||||
expected: map[string]string{
|
||||
// connect_timeout is handled specially and NOT added to RuntimeParams
|
||||
"application_name": "authentik",
|
||||
"statement_timeout": "30000",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with special characters in values",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik proxy v2"}`)),
|
||||
expected: map[string]string{
|
||||
"application_name": "authentik proxy v2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with target_session_attrs",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write","application_name":"authentik"}`)),
|
||||
expected: map[string]string{
|
||||
"application_name": "authentik",
|
||||
// target_session_attrs should NOT appear in RuntimeParams
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify that all expected parameters are present in RuntimeParams
|
||||
for key, expectedValue := range tt.expected {
|
||||
actualValue, exists := result.RuntimeParams[key]
|
||||
assert.True(t, exists, "Expected runtime parameter %s to exist", key)
|
||||
assert.Equal(t, expectedValue, actualValue, "Runtime parameter %s should have value %s", key, expectedValue)
|
||||
}
|
||||
|
||||
// Verify that connect_timeout is handled specially (sets ConnectTimeout field, not RuntimeParams)
|
||||
if tt.name == "base64 encoded JSON with single parameter" || tt.name == "base64 encoded JSON with multiple parameters" {
|
||||
_, hasConnectTimeout := result.RuntimeParams["connect_timeout"]
|
||||
assert.False(t, hasConnectTimeout, "connect_timeout should not appear in RuntimeParams")
|
||||
assert.Equal(t, 10*time.Second, result.ConnectTimeout, "connect_timeout should be set as ConnectTimeout duration")
|
||||
}
|
||||
|
||||
// Verify that target_session_attrs is NOT in RuntimeParams
|
||||
// (it affects connection behavior, not a runtime param)
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_TargetSessionAttrs demonstrates how target_session_attrs
|
||||
// should be properly handled using pgx's ValidateConnect callback
|
||||
func TestBuildConnConfig_TargetSessionAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
connOptions string
|
||||
targetSessionAttrs string
|
||||
expectedValidator pgconn.ValidateConnectFunc
|
||||
validatorDescription string
|
||||
}{
|
||||
{
|
||||
name: "target_session_attrs=read-write",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
validatorDescription: "should validate connection is read-write by checking transaction_read_only=off",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=read-only",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-only"}`)),
|
||||
targetSessionAttrs: "read-only",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadOnly,
|
||||
validatorDescription: "should validate connection is read-only by checking transaction_read_only=on",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=primary",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"primary"}`)),
|
||||
targetSessionAttrs: "primary",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPrimary,
|
||||
validatorDescription: "should validate connection is to primary by checking in_hot_standby=off",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=standby",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"standby"}`)),
|
||||
targetSessionAttrs: "standby",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsStandby,
|
||||
validatorDescription: "should validate connection is to standby by checking in_hot_standby=on",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=prefer-standby",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"prefer-standby"}`)),
|
||||
targetSessionAttrs: "prefer-standby",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPreferStandby,
|
||||
validatorDescription: "should prefer standby connections (affects fallback logic)",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=any (default)",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"any"}`)),
|
||||
targetSessionAttrs: "any",
|
||||
expectedValidator: nil,
|
||||
validatorDescription: "should not set validator as any connection is acceptable",
|
||||
},
|
||||
{
|
||||
name: "no target_session_attrs",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
targetSessionAttrs: "",
|
||||
expectedValidator: nil,
|
||||
validatorDescription: "should not set validator when target_session_attrs is not specified",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify target_session_attrs is NOT in RuntimeParams
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs,
|
||||
"target_session_attrs should not appear in RuntimeParams")
|
||||
|
||||
// Verify ValidateConnect callback is set to the correct standard pgx function
|
||||
if tt.expectedValidator != nil {
|
||||
require.NotNil(t, result.ValidateConnect,
|
||||
"ValidateConnect should be set for target_session_attrs=%s: %s",
|
||||
tt.targetSessionAttrs, tt.validatorDescription)
|
||||
|
||||
// Compare function pointers using reflect to check if it's the same function
|
||||
actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer())
|
||||
expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer())
|
||||
|
||||
assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(),
|
||||
"ValidateConnect should be set to %s for target_session_attrs=%s",
|
||||
expectedFuncPtr.Name(), tt.targetSessionAttrs)
|
||||
|
||||
t.Logf("Expected validator: %s", expectedFuncPtr.Name())
|
||||
t.Logf("Actual validator: %s", actualFuncPtr.Name())
|
||||
} else {
|
||||
assert.Nil(t, result.ValidateConnect,
|
||||
"ValidateConnect should not be set: %s", tt.validatorDescription)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts tests that when multiple
|
||||
// hosts are specified, fallbacks are properly configured along with the validator
|
||||
func TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port string
|
||||
sslMode string
|
||||
connOptions string
|
||||
targetSessionAttrs string
|
||||
expectedValidator pgconn.ValidateConnectFunc
|
||||
expectedPrimaryHost string
|
||||
expectedPrimaryPort uint16
|
||||
expectedFallbacks []*pgconn.FallbackConfig
|
||||
expectTLS bool
|
||||
validatorDescription string
|
||||
}{
|
||||
{
|
||||
name: "multiple hosts with read-write",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432, TLSConfig: nil},
|
||||
{Host: "db3.local", Port: 5432, TLSConfig: nil},
|
||||
},
|
||||
expectTLS: false,
|
||||
validatorDescription: "should set validator and create fallbacks for additional hosts",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with ports specified",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432,5433,5434",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5433, TLSConfig: nil},
|
||||
{Host: "db3.local", Port: 5434, TLSConfig: nil},
|
||||
},
|
||||
expectTLS: false,
|
||||
validatorDescription: "should handle hosts with explicit ports",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with TLS required",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
sslMode: "require",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"require"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
{Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
},
|
||||
expectTLS: true,
|
||||
validatorDescription: "should set TLS config for all hosts when sslmode=require",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with TLS verify-full",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
sslMode: "require",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"verify-full"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
{Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
},
|
||||
expectTLS: true,
|
||||
validatorDescription: "should set TLS config host name for all hosts when sslmode=verify-full",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: tt.sslMode,
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify target_session_attrs is NOT in RuntimeParams
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs,
|
||||
"target_session_attrs should not appear in RuntimeParams")
|
||||
|
||||
// Verify ValidateConnect is set to the correct function
|
||||
require.NotNil(t, result.ValidateConnect,
|
||||
"ValidateConnect should be set for target_session_attrs=%s with multiple hosts",
|
||||
tt.targetSessionAttrs)
|
||||
|
||||
actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer())
|
||||
expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer())
|
||||
|
||||
assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(),
|
||||
"ValidateConnect should be %s for target_session_attrs=%s",
|
||||
expectedFuncPtr.Name(), tt.targetSessionAttrs)
|
||||
|
||||
// Verify the primary host and port
|
||||
assert.Equal(t, tt.expectedPrimaryHost, result.Host,
|
||||
"Primary host should be %s", tt.expectedPrimaryHost)
|
||||
assert.Equal(t, tt.expectedPrimaryPort, result.Port,
|
||||
"Primary port should be %d", tt.expectedPrimaryPort)
|
||||
|
||||
// Verify primary TLSConfig based on sslmode
|
||||
if tt.expectTLS {
|
||||
assert.NotNil(t, result.TLSConfig,
|
||||
"Primary connection should have TLSConfig set when sslmode=%s", tt.sslMode)
|
||||
} else {
|
||||
assert.Nil(t, result.TLSConfig,
|
||||
"Primary connection should not have TLSConfig when sslmode is not set")
|
||||
}
|
||||
|
||||
// Verify Fallbacks are configured for the additional hosts
|
||||
require.Len(t, result.Fallbacks, len(tt.expectedFallbacks),
|
||||
"Should have %d fallback configs for the additional hosts", len(tt.expectedFallbacks))
|
||||
|
||||
// Verify each fallback configuration
|
||||
for i, expectedFb := range tt.expectedFallbacks {
|
||||
actualFb := result.Fallbacks[i]
|
||||
|
||||
assert.Equal(t, expectedFb.Host, actualFb.Host,
|
||||
"Fallback %d host should be %s", i+1, expectedFb.Host)
|
||||
assert.Equal(t, expectedFb.Port, actualFb.Port,
|
||||
"Fallback %d port should be %d", i+1, expectedFb.Port)
|
||||
|
||||
// Verify TLSConfig is set appropriately for fallbacks
|
||||
if tt.expectTLS {
|
||||
assert.NotNil(t, actualFb.TLSConfig,
|
||||
"Fallback %d should have TLSConfig set when sslmode=%s", i+1, tt.sslMode)
|
||||
// Verify InsecureSkipVerify for sslmode=require
|
||||
switch tt.sslMode {
|
||||
case "require":
|
||||
assert.True(t, actualFb.TLSConfig.InsecureSkipVerify,
|
||||
"Fallback %d TLSConfig should have InsecureSkipVerify=true for sslmode=require", i+1)
|
||||
case "verify-full":
|
||||
assert.False(t, actualFb.TLSConfig.InsecureSkipVerify,
|
||||
"Fallback %d TLSConfig should have InsecureSkipVerify=false for sslmode=verify-full", i+1)
|
||||
assert.Equal(t, actualFb.Host, actualFb.TLSConfig.ServerName,
|
||||
"Fallback %d TLSConfig ServerName should match host for sslmode=verify-full", i+1)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, actualFb.TLSConfig,
|
||||
"Fallback %d should not have TLSConfig when sslmode is not set", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Log the configuration for debugging
|
||||
t.Logf("Primary host: %s:%d", result.Host, result.Port)
|
||||
t.Logf("Validator: %s", actualFuncPtr.Name())
|
||||
for i, fb := range result.Fallbacks {
|
||||
t.Logf("Fallback %d: %s:%d", i+1, fb.Host, fb.Port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs tests that multiple hosts
|
||||
// create fallbacks even without target_session_attrs
|
||||
func TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "db1.local,db2.local,db3.local",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify primary host
|
||||
assert.Equal(t, "db1.local", result.Host)
|
||||
assert.Equal(t, uint16(5432), result.Port)
|
||||
|
||||
// Verify fallbacks are created
|
||||
require.Len(t, result.Fallbacks, 2, "Should have 2 fallback configs")
|
||||
assert.Equal(t, "db2.local", result.Fallbacks[0].Host)
|
||||
assert.Equal(t, uint16(5432), result.Fallbacks[0].Port)
|
||||
assert.Equal(t, "db3.local", result.Fallbacks[1].Host)
|
||||
assert.Equal(t, uint16(5432), result.Fallbacks[1].Port)
|
||||
|
||||
// Verify no ValidateConnect is set (no target_session_attrs)
|
||||
assert.Nil(t, result.ValidateConnect)
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_CommaSeparatedPorts_EdgeCases tests edge cases and error scenarios for comma-separated ports
|
||||
func TestBuildConnConfig_CommaSeparatedPorts_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectedHost string
|
||||
expectedPort uint16
|
||||
expectedFallbacks []*pgconn.FallbackConfig
|
||||
}{
|
||||
{
|
||||
name: "invalid port in comma-separated list",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,abc",
|
||||
expectError: true,
|
||||
errorContains: "invalid port value",
|
||||
},
|
||||
{
|
||||
name: "port out of range (too high)",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,99999",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port 99999 is out of valid range",
|
||||
},
|
||||
{
|
||||
name: "port out of range (zero)",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,0",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port 0 must be positive",
|
||||
},
|
||||
{
|
||||
name: "empty port string",
|
||||
host: "db1.local",
|
||||
port: "",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port is required",
|
||||
},
|
||||
{
|
||||
name: "port with only whitespace",
|
||||
host: "db1.local",
|
||||
port: " ",
|
||||
expectError: true,
|
||||
errorContains: "invalid port value",
|
||||
},
|
||||
{
|
||||
name: "mismatched number of hosts and ports",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432",
|
||||
expectError: false,
|
||||
expectedHost: "db1.local",
|
||||
expectedPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra ports than hosts",
|
||||
host: "db1.local",
|
||||
port: "5432,5433",
|
||||
expectError: false,
|
||||
expectedHost: "db1.local",
|
||||
expectedPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
c, err := BuildConnConfig(cfg)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, c)
|
||||
|
||||
assert.Equal(t, tt.expectedHost, c.Host)
|
||||
assert.Equal(t, tt.expectedPort, c.Port)
|
||||
require.Len(t, c.Fallbacks, len(tt.expectedFallbacks))
|
||||
for i, expectedFb := range tt.expectedFallbacks {
|
||||
actualFb := c.Fallbacks[i]
|
||||
assert.Equal(t, expectedFb.Host, actualFb.Host)
|
||||
assert.Equal(t, expectedFb.Port, actualFb.Port)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,39 @@
|
||||
package utils
|
||||
|
||||
import "crypto/tls"
|
||||
import (
|
||||
"crypto/tls"
|
||||
"slices"
|
||||
)
|
||||
|
||||
func GetTLSConfig() *tls.Config {
|
||||
// Based on
|
||||
// https://ssl-config.mozilla.org/#server=go&version=1.25&config=intermediate&guideline=5.7
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.X25519,
|
||||
tls.CurveP256,
|
||||
tls.CurveP384,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
CipherSuites: []uint16{},
|
||||
}
|
||||
|
||||
// Insecure SWEET32 attack ciphers, TLS config uses a fallback
|
||||
insecureCiphersIds := []uint16{
|
||||
excludedCiphers := []uint16{
|
||||
// ChaCha20 is not FIPS validated
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
// Insecure SWEET32 attack ciphers, TLS config uses a fallback
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
}
|
||||
|
||||
defaultSecureCiphers := []uint16{}
|
||||
for _, cs := range tls.CipherSuites() {
|
||||
for _, icsId := range insecureCiphersIds {
|
||||
if cs.ID != icsId {
|
||||
defaultSecureCiphers = append(defaultSecureCiphers, cs.ID)
|
||||
}
|
||||
if slices.Contains(excludedCiphers, cs.ID) {
|
||||
continue
|
||||
}
|
||||
defaultSecureCiphers = append(defaultSecureCiphers, cs.ID)
|
||||
}
|
||||
tlsConfig.CipherSuites = defaultSecureCiphers
|
||||
return tlsConfig
|
||||
|
||||
@@ -63,7 +63,11 @@ func (ws *WebServer) configureProxy() {
|
||||
rp.ErrorHandler = ws.proxyErrorHandler
|
||||
rp.ModifyResponse = ws.proxyModifyResponse
|
||||
ws.mainRouter.PathPrefix(config.Get().Web.Path).Path("/-/health/live/").HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
if ws.upstreamHealthcheck() {
|
||||
rw.WriteHeader(200)
|
||||
} else {
|
||||
rw.WriteHeader(502)
|
||||
}
|
||||
}))
|
||||
ws.mainRouter.PathPrefix(config.Get().Web.Path).HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if !ws.g.IsRunning() {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-http-utils/etag"
|
||||
@@ -17,11 +18,44 @@ import (
|
||||
staticWeb "goauthentik.io/web"
|
||||
)
|
||||
|
||||
// Theme variable placeholder that can be used in file paths
|
||||
// This allows for theme-specific files like logo-%(theme)s.png
|
||||
const themeVariable = "%(theme)s"
|
||||
|
||||
// Valid themes that can be substituted for %(theme)s
|
||||
var validThemes = []string{"light", "dark"}
|
||||
|
||||
type StorageClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
// pathMatchesWithTheme checks if the requested path matches the JWT path,
|
||||
// accounting for theme variable substitution.
|
||||
// If the JWT path contains %(theme)s, it will match the requested path
|
||||
// if substituting %(theme)s with any valid theme produces the requested path.
|
||||
func pathMatchesWithTheme(jwtPath, requestedPath string) bool {
|
||||
// Direct match (no theme variable)
|
||||
if jwtPath == requestedPath {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if JWT path contains theme variable
|
||||
if !strings.Contains(jwtPath, themeVariable) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try substituting each valid theme and check for a match
|
||||
for _, theme := range validThemes {
|
||||
substituted := strings.ReplaceAll(jwtPath, themeVariable, theme)
|
||||
if substituted == requestedPath {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func storageTokenIsValid(usage string, r *http.Request) bool {
|
||||
tokenString := r.URL.Query().Get("token")
|
||||
if tokenString == "" {
|
||||
@@ -51,11 +85,8 @@ func storageTokenIsValid(usage string, r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if claims.Path != fmt.Sprintf("%s/%s", usage, r.URL.Path) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
requestedPath := fmt.Sprintf("%s/%s", usage, r.URL.Path)
|
||||
return pathMatchesWithTheme(claims.Path, requestedPath)
|
||||
}
|
||||
|
||||
func (ws *WebServer) configureStatic() {
|
||||
|
||||
95
internal/web/static_test.go
Normal file
95
internal/web/static_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package web
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPathMatchesWithTheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jwtPath string
|
||||
requestedPath string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match without theme variable",
|
||||
jwtPath: "media/public/logo.png",
|
||||
requestedPath: "media/public/logo.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match without theme variable",
|
||||
jwtPath: "media/public/logo.png",
|
||||
requestedPath: "media/public/other.png",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "theme variable matches light theme",
|
||||
jwtPath: "media/public/logo-%(theme)s.png",
|
||||
requestedPath: "media/public/logo-light.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "theme variable matches dark theme",
|
||||
jwtPath: "media/public/logo-%(theme)s.png",
|
||||
requestedPath: "media/public/logo-dark.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "theme variable does not match invalid theme",
|
||||
jwtPath: "media/public/logo-%(theme)s.png",
|
||||
requestedPath: "media/public/logo-blue.png",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "theme variable in directory path",
|
||||
jwtPath: "media/%(theme)s/logo.png",
|
||||
requestedPath: "media/light/logo.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple theme variables",
|
||||
jwtPath: "media/%(theme)s/logo-%(theme)s.png",
|
||||
requestedPath: "media/light/logo-light.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple theme variables with dark",
|
||||
jwtPath: "media/%(theme)s/logo-%(theme)s.png",
|
||||
requestedPath: "media/dark/logo-dark.png",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple theme variables mixed themes should not match",
|
||||
jwtPath: "media/%(theme)s/logo-%(theme)s.png",
|
||||
requestedPath: "media/light/logo-dark.png",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "theme variable with nested path",
|
||||
jwtPath: "media/public/brand/logo-%(theme)s.svg",
|
||||
requestedPath: "media/public/brand/logo-dark.svg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty paths",
|
||||
jwtPath: "",
|
||||
requestedPath: "",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "theme variable only",
|
||||
jwtPath: "%(theme)s",
|
||||
requestedPath: "light",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := pathMatchesWithTheme(tt.jwtPath, tt.requestedPath)
|
||||
if got != tt.want {
|
||||
t.Errorf("pathMatchesWithTheme(%q, %q) = %v, want %v",
|
||||
tt.jwtPath, tt.requestedPath, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -97,23 +97,27 @@ func NewWebServer() *WebServer {
|
||||
if sp := config.Get().Web.Path; sp != "/" {
|
||||
ws.mainRouter.Path("/").Handler(http.RedirectHandler(sp, http.StatusFound))
|
||||
}
|
||||
hcUrl := fmt.Sprintf("%s%s-/health/live/", ws.upstreamURL.String(), config.Get().Web.Path)
|
||||
ws.g = gounicorn.New(func() bool {
|
||||
req, err := http.NewRequest(http.MethodGet, hcUrl, nil)
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to create request for healthcheck")
|
||||
return false
|
||||
}
|
||||
req.Header.Set("User-Agent", "goauthentik.io/router/healthcheck")
|
||||
res, err := ws.upstreamHttpClient().Do(req)
|
||||
if err == nil && res.StatusCode >= 200 && res.StatusCode < 300 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return ws.upstreamHealthcheck()
|
||||
})
|
||||
return ws
|
||||
}
|
||||
|
||||
func (ws *WebServer) upstreamHealthcheck() bool {
|
||||
hcUrl := fmt.Sprintf("%s%s-/health/live/", ws.upstreamURL.String(), config.Get().Web.Path)
|
||||
req, err := http.NewRequest(http.MethodGet, hcUrl, nil)
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to create request for healthcheck")
|
||||
return false
|
||||
}
|
||||
req.Header.Set("User-Agent", "goauthentik.io/router/healthcheck")
|
||||
res, err := ws.upstreamHttpClient().Do(req)
|
||||
if err == nil && res.StatusCode >= 200 && res.StatusCode < 300 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ws *WebServer) prepareKeys() {
|
||||
tmp := os.TempDir()
|
||||
key := base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(64))
|
||||
|
||||
@@ -18,7 +18,7 @@ Parameters:
|
||||
Description: authentik Docker image
|
||||
AuthentikVersion:
|
||||
Type: String
|
||||
Default: 2025.12.0-rc1
|
||||
Default: 2025.12.1
|
||||
Description: authentik Docker image tag
|
||||
AuthentikServerCPU:
|
||||
Type: Number
|
||||
|
||||
@@ -30,10 +30,11 @@ class BaseMigration:
|
||||
def __init__(self, cur: Any, con: Any):
|
||||
self.cur = cur
|
||||
self.con = con
|
||||
self.log = get_logger().bind()
|
||||
|
||||
def system_crit(self, command: str):
|
||||
"""Run system command"""
|
||||
LOGGER.debug("Running system_crit command", command=command)
|
||||
self.log.debug("Running system_crit command", command=command)
|
||||
retval = system(command) # nosec
|
||||
if retval != 0:
|
||||
raise CommandError("Migration error")
|
||||
@@ -73,6 +74,7 @@ def release_lock(conn: Connection, cursor: Cursor):
|
||||
|
||||
|
||||
def run_migrations():
|
||||
conn_opts = CONFIG.get_dict_from_b64_json("postgresql.conn_options", default={})
|
||||
conn = connect(
|
||||
dbname=CONFIG.get("postgresql.name"),
|
||||
user=CONFIG.get("postgresql.user"),
|
||||
@@ -83,6 +85,7 @@ def run_migrations():
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||
**conn_opts,
|
||||
)
|
||||
curr = conn.cursor()
|
||||
try:
|
||||
|
||||
41
lifecycle/system_migrations/to_2025_12_group_duplicate.py
Normal file
41
lifecycle/system_migrations/to_2025_12_group_duplicate.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# flake8: noqa
|
||||
from lifecycle.migrate import BaseMigration
|
||||
|
||||
SQL_STATEMENT = """
|
||||
SELECT "authentik_core_group"."name" AS "name",
|
||||
Count("authentik_core_group"."name") AS "name__count"
|
||||
FROM "authentik_core_group" GROUP BY 1
|
||||
HAVING Count("authentik_core_group"."name") > 1
|
||||
ORDER BY 2 DESC,
|
||||
1 ASC
|
||||
"""
|
||||
|
||||
|
||||
class DuplicateNameError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class Migration(BaseMigration):
|
||||
def needs_migration(self) -> bool:
|
||||
self.cur.execute(
|
||||
"select 1 from information_schema.tables where table_name = 'django_migrations';"
|
||||
)
|
||||
if not bool(self.cur.rowcount):
|
||||
# No django_migrations table, no data to check
|
||||
return False
|
||||
# migration that introduces the uniqueness
|
||||
self.cur.execute(
|
||||
"select 1 from django_migrations where app = 'authentik_core' and name = '0056_user_roles';"
|
||||
)
|
||||
return not bool(self.cur.rowcount)
|
||||
|
||||
def run(self):
|
||||
rows = self.cur.execute(SQL_STATEMENT).fetchall()
|
||||
if len(rows):
|
||||
for row in rows:
|
||||
self.log.error(
|
||||
"Group with duplicate name detected", group_name=row[0], count=row[1]
|
||||
)
|
||||
raise DuplicateNameError(
|
||||
f"authentik 2025.12 forbids duplicate group names. For a list of duplicate groups, see logging output above. Please rename the offending groups and re-run the migration. For more information, see: https://version-2025-12.goauthentik.io/releases/2025.12/#group-name-uniqueness"
|
||||
)
|
||||
@@ -18,6 +18,7 @@ def check_postgres():
|
||||
if attempt >= CHECK_THRESHOLD:
|
||||
sysexit(1)
|
||||
try:
|
||||
conn_opts = CONFIG.get_dict_from_b64_json("postgresql.conn_options", default={})
|
||||
conn = connect(
|
||||
dbname=CONFIG.refresh("postgresql.name"),
|
||||
user=CONFIG.refresh("postgresql.user"),
|
||||
@@ -28,6 +29,7 @@ def check_postgres():
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||
**conn_opts,
|
||||
)
|
||||
conn.cursor()
|
||||
break
|
||||
|
||||
4
package-lock.json
generated
4
package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@goauthentik/authentik",
|
||||
"version": "2025.12.0-rc1",
|
||||
"version": "2025.12.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@goauthentik/authentik",
|
||||
"version": "2025.12.0-rc1",
|
||||
"version": "2025.12.1",
|
||||
"dependencies": {
|
||||
"@eslint/js": "^9.39.1",
|
||||
"@goauthentik/eslint-config": "./packages/eslint-config",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@goauthentik/authentik",
|
||||
"version": "2025.12.0-rc1",
|
||||
"version": "2025.12.1",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
|
||||
@@ -1,27 +1,18 @@
|
||||
"""Convenient shortcuts to manage or check object permissions."""
|
||||
|
||||
from functools import lru_cache, partial
|
||||
from functools import lru_cache
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import connection
|
||||
from django.db.models import (
|
||||
AutoField,
|
||||
BigIntegerField,
|
||||
CharField,
|
||||
Count,
|
||||
ForeignKey,
|
||||
IntegerField,
|
||||
Model,
|
||||
PositiveIntegerField,
|
||||
PositiveSmallIntegerField,
|
||||
QuerySet,
|
||||
SmallIntegerField,
|
||||
UUIDField,
|
||||
)
|
||||
from django.db.models.expressions import Value
|
||||
from django.db.models.functions import Cast, Replace
|
||||
from django.db.models.expressions import RawSQL
|
||||
|
||||
from guardian.core import ObjectPermissionChecker
|
||||
from guardian.ctypes import get_content_type
|
||||
@@ -295,42 +286,33 @@ def get_objects_for_user( # noqa: PLR0912 PLR0915
|
||||
.filter(object_pk_count__gte=len(codenames))
|
||||
)
|
||||
|
||||
# object_pk is a varchar, while the queryset's pk is probably an integer or a uuid, so we cast
|
||||
handle_pk_field = _handle_pk_field(queryset)
|
||||
if handle_pk_field is not None:
|
||||
perms_queryset = perms_queryset.annotate(obj_pk=handle_pk_field(expression=pk_field))
|
||||
pk_field = "obj_pk"
|
||||
|
||||
return queryset.filter(pk__in=perms_queryset.values_list(pk_field, flat=True))
|
||||
|
||||
|
||||
def _handle_pk_field(queryset):
|
||||
# pk is either UUID or an integer type, while object_pk is a varchar
|
||||
pk = queryset.model._meta.pk
|
||||
|
||||
if isinstance(pk, ForeignKey):
|
||||
return _handle_pk_field(pk.target_field)
|
||||
def _cast_type(pk):
|
||||
if isinstance(pk, ForeignKey):
|
||||
return _cast_type(pk.target_field)
|
||||
if isinstance(pk, UUIDField):
|
||||
return "uuid"
|
||||
return "bigint"
|
||||
|
||||
if isinstance( # noqa: UP038
|
||||
pk,
|
||||
(
|
||||
IntegerField,
|
||||
AutoField,
|
||||
BigIntegerField,
|
||||
PositiveIntegerField,
|
||||
PositiveSmallIntegerField,
|
||||
SmallIntegerField,
|
||||
),
|
||||
):
|
||||
return partial(Cast, output_field=BigIntegerField())
|
||||
cast_type = _cast_type(pk)
|
||||
|
||||
if isinstance(pk, UUIDField):
|
||||
if connection.features.has_native_uuid_field:
|
||||
return partial(Cast, output_field=UUIDField())
|
||||
return partial(
|
||||
Replace,
|
||||
text=Value("-"),
|
||||
replacement=Value(""),
|
||||
output_field=CharField(),
|
||||
)
|
||||
|
||||
return None
|
||||
perms_queryset = perms_queryset.values_list(pk_field, flat=True)
|
||||
# The raw subquery is done to ensure that casting only takes place after the WHERE clause of
|
||||
# `perms_queryset` is ran. Otherwise, the query planner may decide to cast every `object_pk`,
|
||||
# which breaks (for example) if it tries to cast an integer to a UUID. In such a case, the WHERE
|
||||
# of `perms_queryset` will remove any integer.
|
||||
# However, the subquery might get optimized out by the query planner, which would cause the same
|
||||
# cast issue as before. To prevent the subquery from being collapsed in the query below, we add
|
||||
# OFFSET 0.
|
||||
perms_subquery_sql, perms_subquery_params = perms_queryset.query.sql_with_params()
|
||||
subquery = RawSQL(
|
||||
f"""
|
||||
SELECT ("permission_subquery"."{pk_field}")::{cast_type} as "object_pk"
|
||||
FROM ({perms_subquery_sql}) "permission_subquery"
|
||||
OFFSET 0
|
||||
""", # nosec
|
||||
perms_subquery_params,
|
||||
)
|
||||
return queryset.filter(pk__in=subquery)
|
||||
|
||||
@@ -449,6 +449,7 @@ class _PostgresConsumer(Consumer):
|
||||
pass
|
||||
self.to_unlock.add(str(message.message_id))
|
||||
task = message.options.pop("task", None)
|
||||
m = b"" if state == TaskState.DONE else message.encode()
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -456,7 +457,7 @@ class _PostgresConsumer(Consumer):
|
||||
state=TaskState.QUEUED,
|
||||
).update(
|
||||
state=state,
|
||||
message=message.encode(),
|
||||
message=m,
|
||||
mtime=timezone.now(),
|
||||
eta=None,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "authentik"
|
||||
version = "2025.12.0-rc1"
|
||||
version = "2025.12.1"
|
||||
description = ""
|
||||
authors = [{ name = "authentik Team", email = "hello@goauthentik.io" }]
|
||||
requires-python = "==3.13.*"
|
||||
@@ -12,7 +12,7 @@ dependencies = [
|
||||
"dacite==1.9.2",
|
||||
"deepmerge==2.0",
|
||||
"defusedxml==0.7.1",
|
||||
"django==5.2.9",
|
||||
"django==5.2.10",
|
||||
"django-channels-postgres",
|
||||
"django-countries==7.6.1",
|
||||
"django-cte==2.0.0",
|
||||
|
||||
91
schema.yml
91
schema.yml
@@ -1,7 +1,7 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: authentik
|
||||
version: 2025.12.0-rc1
|
||||
version: 2025.12.1
|
||||
description: Making authentication simple.
|
||||
contact:
|
||||
email: hello@goauthentik.io
|
||||
@@ -3385,6 +3385,11 @@ paths:
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
- in: query
|
||||
name: include_inherited_roles
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
- in: query
|
||||
name: include_parents
|
||||
schema:
|
||||
@@ -3478,6 +3483,11 @@ paths:
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
- in: query
|
||||
name: include_inherited_roles
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
- in: query
|
||||
name: include_parents
|
||||
schema:
|
||||
@@ -20052,6 +20062,16 @@ paths:
|
||||
operationId: rbac_roles_list
|
||||
description: Role viewset
|
||||
parameters:
|
||||
- in: query
|
||||
name: ak_groups
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
- in: query
|
||||
name: inherited
|
||||
schema:
|
||||
type: boolean
|
||||
description: Include inherited roles (requires users or ak_groups filter)
|
||||
- in: query
|
||||
name: managed
|
||||
schema:
|
||||
@@ -20072,11 +20092,7 @@ paths:
|
||||
- in: query
|
||||
name: users
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
explode: true
|
||||
style: form
|
||||
type: integer
|
||||
tags:
|
||||
- rbac
|
||||
security:
|
||||
@@ -32762,12 +32778,18 @@ components:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/Config'
|
||||
readOnly: true
|
||||
license_status:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/LicenseStatusEnum'
|
||||
readOnly: true
|
||||
nullable: true
|
||||
required:
|
||||
- auth_terminate_session_on_expiry
|
||||
- authorization_flow
|
||||
- device_id
|
||||
- jwks_auth
|
||||
- jwks_challenge
|
||||
- license_status
|
||||
- nss_gid_offset
|
||||
- nss_uid_offset
|
||||
- refresh_interval
|
||||
@@ -33586,7 +33608,8 @@ components:
|
||||
minLength: 1
|
||||
default: ak-stage-authenticator-email
|
||||
code:
|
||||
type: integer
|
||||
type: string
|
||||
minLength: 1
|
||||
email:
|
||||
type: string
|
||||
minLength: 1
|
||||
@@ -33833,7 +33856,8 @@ components:
|
||||
minLength: 1
|
||||
default: ak-stage-authenticator-sms
|
||||
code:
|
||||
type: integer
|
||||
type: string
|
||||
minLength: 1
|
||||
phone_number:
|
||||
type: string
|
||||
minLength: 1
|
||||
@@ -34038,7 +34062,7 @@ components:
|
||||
minimum: 0
|
||||
token_length:
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
maximum: 100
|
||||
minimum: 0
|
||||
required:
|
||||
- component
|
||||
@@ -34069,7 +34093,7 @@ components:
|
||||
minimum: 0
|
||||
token_length:
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
maximum: 100
|
||||
minimum: 0
|
||||
required:
|
||||
- name
|
||||
@@ -34107,7 +34131,8 @@ components:
|
||||
minLength: 1
|
||||
default: ak-stage-authenticator-totp
|
||||
code:
|
||||
type: integer
|
||||
type: string
|
||||
minLength: 1
|
||||
required:
|
||||
- code
|
||||
AuthenticatorTOTPStage:
|
||||
@@ -34794,6 +34819,7 @@ components:
|
||||
CapabilitiesEnum:
|
||||
enum:
|
||||
- can_save_media
|
||||
- can_save_reports
|
||||
- can_geo_ip
|
||||
- can_asn
|
||||
- can_impersonate
|
||||
@@ -35388,10 +35414,14 @@ components:
|
||||
model:
|
||||
type: string
|
||||
readOnly: true
|
||||
verbose_name_plural:
|
||||
type: string
|
||||
readOnly: true
|
||||
required:
|
||||
- app_label
|
||||
- id
|
||||
- model
|
||||
- verbose_name_plural
|
||||
ContextualFlowInfo:
|
||||
type: object
|
||||
description: Contextual flow information for a challenge
|
||||
@@ -35739,7 +35769,7 @@ components:
|
||||
readOnly: true
|
||||
requested_by:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/RequestedBy'
|
||||
- $ref: '#/components/schemas/PartialUser'
|
||||
readOnly: true
|
||||
requested_on:
|
||||
type: string
|
||||
@@ -38766,6 +38796,12 @@ components:
|
||||
items:
|
||||
$ref: '#/components/schemas/Role'
|
||||
readOnly: true
|
||||
inherited_roles_obj:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Role'
|
||||
readOnly: true
|
||||
nullable: true
|
||||
children:
|
||||
type: array
|
||||
items:
|
||||
@@ -38781,6 +38817,7 @@ components:
|
||||
required:
|
||||
- children
|
||||
- children_obj
|
||||
- inherited_roles_obj
|
||||
- name
|
||||
- num_pk
|
||||
- parents_obj
|
||||
@@ -40876,6 +40913,16 @@ components:
|
||||
minLength: 1
|
||||
required:
|
||||
- key
|
||||
LicenseStatusEnum:
|
||||
enum:
|
||||
- unlicensed
|
||||
- valid
|
||||
- expired
|
||||
- expiry_soon
|
||||
- limit_exceeded_admin
|
||||
- limit_exceeded_user
|
||||
- read_only
|
||||
type: string
|
||||
LicenseSummary:
|
||||
type: object
|
||||
description: Serializer for license status
|
||||
@@ -45767,7 +45814,7 @@ components:
|
||||
minimum: 0
|
||||
token_length:
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
maximum: 100
|
||||
minimum: 0
|
||||
PatchedAuthenticatorTOTPStageRequest:
|
||||
type: object
|
||||
@@ -51253,22 +51300,6 @@ components:
|
||||
minimum: -2147483648
|
||||
required:
|
||||
- name
|
||||
RequestedBy:
|
||||
type: object
|
||||
properties:
|
||||
pk:
|
||||
type: integer
|
||||
readOnly: true
|
||||
title: ID
|
||||
username:
|
||||
type: string
|
||||
description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_
|
||||
only.
|
||||
pattern: ^[\w.@+-]+$
|
||||
maxLength: 150
|
||||
required:
|
||||
- pk
|
||||
- username
|
||||
ResidentKeyRequirementEnum:
|
||||
enum:
|
||||
- discouraged
|
||||
@@ -53775,7 +53806,7 @@ components:
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
maxLength: 16
|
||||
maxLength: 100
|
||||
required:
|
||||
- token
|
||||
SubModeEnum:
|
||||
|
||||
6
scripts/generate_docker_compose.py
Normal file → Executable file
6
scripts/generate_docker_compose.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from yaml import safe_dump
|
||||
|
||||
from authentik import authentik_version
|
||||
@@ -42,7 +44,7 @@ base = {
|
||||
"image": authentik_image,
|
||||
"ports": ["${COMPOSE_PORT_HTTP:-9000}:9000", "${COMPOSE_PORT_HTTPS:-9443}:9443"],
|
||||
"restart": "unless-stopped",
|
||||
"volumes": ["./media:/data/media", "./custom-templates:/templates"],
|
||||
"volumes": ["./data:/data", "./custom-templates:/templates"],
|
||||
},
|
||||
"worker": {
|
||||
"command": "worker",
|
||||
@@ -62,7 +64,7 @@ base = {
|
||||
"user": "root",
|
||||
"volumes": [
|
||||
"/var/run/docker.sock:/var/run/docker.sock",
|
||||
"./media:/data/media",
|
||||
"./data:/data",
|
||||
"./certs:/certs",
|
||||
"./custom-templates:/templates",
|
||||
],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user