Compare commits

..

1 Commits

Author SHA1 Message Date
Teffen Ellis
fbd9460720 web: Tidy timestamps. 2025-09-29 21:58:22 +02:00
461 changed files with 4232 additions and 12746 deletions

View File

@@ -3,7 +3,6 @@ services:
image: docker.io/library/postgres:${PSQL_TAG:-16}
volumes:
- db-data:/var/lib/postgresql/data
command: "-c log_statement=all"
environment:
POSTGRES_USER: authentik
POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"

View File

@@ -1,28 +0,0 @@
name: "Process test results"
description: Convert test results to JUnit, add them to GitHub Actions and codecov
inputs:
flags:
description: Codecov flags
runs:
using: "composite"
steps:
- uses: codecov/codecov-action@v5
with:
flags: ${{ inputs.flags }}
use_oidc: true
- uses: codecov/test-results-action@v1
with:
flags: ${{ inputs.flags }}
file: unittest.xml
use_oidc: true
- name: PostgreSQL Logs
shell: bash
run: |
if [[ $ACTIONS_RUNNER_DEBUG == 'true' || $ACTIONS_STEP_DEBUG == 'true' ]]; then
docker stop setup-postgresql-1
echo "::group::PostgreSQL Logs"
docker logs setup-postgresql-1
echo "::endgroup::"
fi

View File

@@ -74,7 +74,7 @@ jobs:
mkdir -p ./gen-go-api
- name: Setup node
if: ${{ !inputs.release }}
uses: actions/setup-node@v5
uses: actions/setup-node@v4
with:
node-version-file: web/package.json
cache: "npm"

View File

@@ -113,10 +113,6 @@ jobs:
CI_TOTAL_RUNS: "5"
run: |
uv run make ci-test
- uses: ./.github/actions/test-results
if: ${{ always() }}
with:
flags: unit-migrate
test-unittest:
name: test-unittest - PostgreSQL ${{ matrix.psql }} - Run ${{ matrix.run_id }}/5
runs-on: ubuntu-latest
@@ -143,10 +139,17 @@ jobs:
CI_TOTAL_RUNS: "5"
run: |
uv run make ci-test
- uses: ./.github/actions/test-results
if: ${{ always() }}
- if: ${{ always() }}
uses: codecov/codecov-action@v5
with:
flags: unit
use_oidc: true
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: unit
file: unittest.xml
use_oidc: true
test-integration:
runs-on: ubuntu-latest
timeout-minutes: 30
@@ -160,10 +163,17 @@ jobs:
run: |
uv run coverage run manage.py test tests/integration
uv run coverage xml
- uses: ./.github/actions/test-results
if: ${{ always() }}
- if: ${{ always() }}
uses: codecov/codecov-action@v5
with:
flags: integration
use_oidc: true
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: integration
file: unittest.xml
use_oidc: true
test-e2e:
name: test-e2e (${{ matrix.job.name }})
runs-on: ubuntu-latest
@@ -212,10 +222,17 @@ jobs:
run: |
uv run coverage run manage.py test ${{ matrix.job.glob }}
uv run coverage xml
- uses: ./.github/actions/test-results
if: ${{ always() }}
- if: ${{ always() }}
uses: codecov/codecov-action@v5
with:
flags: e2e
use_oidc: true
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: e2e
file: unittest.xml
use_oidc: true
ci-core-mark:
if: always()
needs:

View File

@@ -20,14 +20,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Find Comment
uses: peter-evans/find-comment@v4
uses: peter-evans/find-comment@v3
id: fc
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: "github-actions[bot]"
body-includes: authentik translations instructions
- name: Create or update comment
uses: peter-evans/create-or-update-comment@v5
uses: peter-evans/create-or-update-comment@v4
with:
comment-id: ${{ steps.fc.outputs.comment-id }}
issue-number: ${{ github.event.pull_request.number }}

View File

@@ -24,7 +24,6 @@ Makefile @goauthentik/infrastructure
.editorconfig @goauthentik/infrastructure
CODEOWNERS @goauthentik/infrastructure
# Backend packages
packages/django-postgres-cache @goauthentik/backend
packages/django-dramatiq-postgres @goauthentik/backend
# Web packages
packages/docusaurus-config @goauthentik/frontend

View File

@@ -26,7 +26,7 @@ RUN npm run build && \
npm run build:sfe
# Stage 2: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.1-bookworm AS go-builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH
@@ -119,11 +119,7 @@ RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/v
libltdl-dev && \
curl https://sh.rustup.rs -sSf | sh -s -- -y
ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec" \
# https://github.com/rust-lang/rustup/issues/2949
# Fixes issues where the rust version in the build cache is older than latest
# and rustup tries to update it, which fails
RUSTUP_PERMIT_COPY_RENAME="true"
ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
--mount=type=bind,target=uv.lock,src=uv.lock \

View File

@@ -1,9 +0,0 @@
from django.dispatch import receiver
from authentik.admin.tasks import _set_prom_info
from authentik.root.signals import post_startup
@receiver(post_startup)
def post_startup_admin_metrics(sender, **_):
_set_prom_info()

View File

@@ -2,6 +2,7 @@
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq import actor
from packaging.version import parse
from requests import RequestException
@@ -12,7 +13,7 @@ from authentik.admin.apps import PROM_INFO
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
VERSION_NULL = "0.0.0"
@@ -34,7 +35,7 @@ def _set_prom_info():
@actor(description=_("Update latest version info."))
def update_latest_version():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.info("Version check disabled.")
@@ -71,3 +72,6 @@ def update_latest_version():
except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
raise exc
_set_prom_info()

View File

@@ -12,7 +12,7 @@ from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.text import slugify
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTaskNotFound
from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
from dramatiq.actor import actor
from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger
@@ -39,7 +39,6 @@ from authentik.events.logs import capture_logs
from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG
from authentik.tasks.apps import PRIORITY_HIGH
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
from authentik.tasks.schedules.models import Schedule
from authentik.tenants.models import Tenant
@@ -156,7 +155,7 @@ def blueprints_find() -> list[BlueprintFile]:
throws=(DatabaseError, ProgrammingError, InternalError),
)
def blueprints_discovery(path: str | None = None):
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
count = 0
for blueprint in blueprints_find():
if path and blueprint.path != path:
@@ -196,7 +195,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
@actor(description=_("Apply single blueprint."))
def apply_blueprint(instance_pk: UUID):
try:
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
except CurrentTaskNotFound:
self = Task()
self.set_uid(str(instance_pk))

View File

@@ -334,21 +334,6 @@ class UserPasswordSetSerializer(PassiveSerializer):
password = CharField(required=True)
class UserServiceAccountSerializer(PassiveSerializer):
"""Payload to create a service account"""
name = CharField(
required=True,
validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
)
create_group = BooleanField(default=False)
expiring = BooleanField(default=True)
expires = DateTimeField(
required=False,
help_text="If not provided, valid for 360 days",
)
class UsersFilter(FilterSet):
"""Filter for users"""
@@ -509,7 +494,18 @@ class UserViewSet(UsedByMixin, ModelViewSet):
@permission_required(None, ["authentik_core.add_user", "authentik_core.add_token"])
@extend_schema(
request=UserServiceAccountSerializer,
request=inline_serializer(
"UserServiceAccountSerializer",
{
"name": CharField(required=True),
"create_group": BooleanField(default=False),
"expiring": BooleanField(default=True),
"expires": DateTimeField(
required=False,
help_text="If not provided, valid for 360 days",
),
},
),
responses={
200: inline_serializer(
"UserServiceAccountResponse",
@@ -531,12 +527,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
)
def service_account(self, request: Request) -> Response:
"""Create a new user account that is marked as a service account"""
data = UserServiceAccountSerializer(data=request.data)
data.is_valid(raise_exception=True)
expires = data.validated_data.get("expires", now() + timedelta(days=360))
username = request.data.get("name")
create_group = request.data.get("create_group", False)
expiring = request.data.get("expiring", True)
expires = request.data.get("expires", now() + timedelta(days=360))
username = data.validated_data["name"]
expiring = data.validated_data["expiring"]
with atomic():
try:
user: User = User.objects.create(
@@ -554,10 +549,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
"user_uid": user.uid,
"user_pk": user.pk,
}
if data.validated_data["create_group"] and self.request.user.has_perm(
"authentik_core.add_group"
):
group = Group.objects.create(name=username)
if create_group and self.request.user.has_perm("authentik_core.add_group"):
group = Group.objects.create(
name=username,
)
group.users.add(user)
response["group_pk"] = str(group.pk)
token = Token.objects.create(
@@ -570,29 +565,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
response["token"] = token.key
return Response(response)
except IntegrityError as exc:
error_msg = str(exc).lower()
if "unique" in error_msg:
return Response(
data={
"non_field_errors": [
_("A user/group with these details already exists")
]
},
status=400,
)
else:
LOGGER.warning("Service account creation failed", exc=exc)
return Response(
data={"non_field_errors": [_("Unable to create user")]},
status=400,
)
except (ValueError, TypeError) as exc:
LOGGER.error("Unexpected error during service account creation", exc=exc)
return Response(
data={"non_field_errors": [_("Unknown error occurred")]},
status=500,
)
return Response(data={"non_field_errors": [str(exc)]}, status=400)
@extend_schema(responses={200: SessionUserSerializer(many=False)})
@action(

View File

@@ -13,6 +13,14 @@ import authentik.core.models
import authentik.lib.models
def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from django.contrib.sessions.backends.cache import KEY_PREFIX
from django.core.cache import cache
session_keys = cache.keys(KEY_PREFIX + "*")
cache.delete_many(session_keys)
def fix_duplicates(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Token = apps.get_model("authentik_core", "token")
@@ -143,6 +151,9 @@ class Migration(migrations.Migration):
"abstract": False,
},
),
migrations.RunPython(
code=migrate_sessions,
),
migrations.AlterField(
model_name="application",
name="meta_launch_url",

View File

@@ -7,10 +7,15 @@ from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_K
from django.db import migrations, models
import django.db.models.deletion
from django.conf import settings
from django.contrib.sessions.backends.cache import KEY_PREFIX
from django.utils.timezone import now, timedelta
from authentik.lib.migrations import progress_bar
from authentik.root.middleware import ClientIPMiddleware
SESSION_CACHE_ALIAS = "default"
class PickleSerializer:
"""
Simple wrapper around pickle to be used in signing.dumps()/loads() and
@@ -78,6 +83,27 @@ def _migrate_session(
)
def migrate_redis_sessions(apps, schema_editor):
from django.core.cache import caches
db_alias = schema_editor.connection.alias
cache = caches[SESSION_CACHE_ALIAS]
# Not a redis cache, skipping
if not hasattr(cache, "keys"):
return
print("\nMigrating Redis sessions to database, this might take a couple of minutes...")
for key, session_data in progress_bar(cache.get_many(cache.keys(f"{KEY_PREFIX}*")).items()):
_migrate_session(
apps=apps,
db_alias=db_alias,
session_key=key.removeprefix(KEY_PREFIX),
session_data=session_data,
expires=now() + timedelta(seconds=cache.ttl(key)),
)
def migrate_database_sessions(apps, schema_editor):
DjangoSession = apps.get_model("sessions", "Session")
db_alias = schema_editor.connection.alias
@@ -205,6 +231,10 @@ class Migration(migrations.Migration):
"verbose_name_plural": "Authenticated Sessions",
},
),
migrations.RunPython(
code=migrate_redis_sessions,
reverse_code=migrations.RunPython.noop,
),
migrations.RunPython(
code=migrate_database_sessions,
reverse_code=migrations.RunPython.noop,

View File

@@ -406,8 +406,6 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
def locale(self, request: HttpRequest | None = None) -> str:
"""Get the locale the user has configured"""
if request and hasattr(request, "LANGUAGE_CODE"):
return request.LANGUAGE_CODE
try:
return self.attributes.get("settings", {}).get("locale", "")

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_postgres_cache.tasks import clear_expired_cache
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger
@@ -15,14 +15,14 @@ from authentik.core.models import (
User,
)
from authentik.lib.utils.db import chunked_queryset
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@actor(description=_("Remove expired objects."))
def clean_expired_models():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
for cls in ExpiringModel.__subclasses__():
cls: ExpiringModel
objects = (
@@ -33,12 +33,11 @@ def clean_expired_models():
obj.expire_action()
LOGGER.debug("Expired models", model=cls, amount=amount)
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
clear_expired_cache()
@actor(description=_("Remove temporary users created by SAML Sources."))
def clean_temporary_users():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
_now = datetime.now()
deleted_users = 0
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):

View File

@@ -469,274 +469,3 @@ class TestUsersAPI(APITestCase):
body = loads(response.content)
self.assertEqual(len(body["results"]), 2)
self.assertEqual(body["results"][0]["pk"], user.pk)
def test_service_account_validation_empty_username(self):
"""Test service account creation with empty/blank username validation"""
self.client.force_login(self.admin)
# Test with empty string
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "",
"create_group": True,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"name": ["This field may not be blank."]},
)
# Test with only whitespace
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": " ",
"create_group": True,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"name": ["This field may not be blank."]},
)
# Test with tab and newline characters
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "\t\n",
"create_group": True,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"name": ["This field may not be blank."]},
)
def test_service_account_validation_valid_username(self):
"""Test service account creation with valid username"""
self.client.force_login(self.admin)
# Test with valid username
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "valid-service-account",
"create_group": True,
},
)
self.assertEqual(response.status_code, 200)
# Verify response structure
body = loads(response.content)
self.assertIn("username", body)
self.assertIn("user_uid", body)
self.assertIn("user_pk", body)
self.assertIn("group_pk", body) # Should exist since create_group=True
self.assertIn("token", body)
# Verify field types
self.assertEqual(body["username"], "valid-service-account")
self.assertIsInstance(body["user_pk"], int)
self.assertIsInstance(body["user_uid"], str)
self.assertIsInstance(body["token"], str)
self.assertIsInstance(body["group_pk"], str)
def test_service_account_validation_without_group(self):
"""Test service account creation without creating a group"""
self.client.force_login(self.admin)
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "no-group-service-account",
"create_group": False,
},
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertIn("username", body)
self.assertIn("user_uid", body)
self.assertIn("user_pk", body)
self.assertIn("token", body)
# Should NOT have group_pk when create_group=False
self.assertNotIn("group_pk", body)
def test_service_account_validation_duplicate_username(self):
"""Test service account creation with duplicate username"""
self.client.force_login(self.admin)
# Create first service account
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "duplicate-test",
"create_group": True,
},
)
self.assertEqual(response.status_code, 200)
# Attempt to create second with same username
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "duplicate-test",
"create_group": True,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"name": ["This field must be unique."]},
)
def test_service_account_validation_invalid_create_group(self):
"""Test service account creation with invalid create_group field"""
self.client.force_login(self.admin)
# Test with string instead of boolean
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"create_group": "invalid",
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"create_group": ["Must be a valid boolean."]},
)
# Test with number instead of boolean
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"create_group": 123,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"create_group": ["Must be a valid boolean."]},
)
def test_service_account_validation_invalid_expiring(self):
"""Test service account creation with invalid expiring field"""
self.client.force_login(self.admin)
# Test with string instead of boolean
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"expiring": "invalid",
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"expiring": ["Must be a valid boolean."]},
)
def test_service_account_validation_invalid_expires(self):
"""Test service account creation with invalid expires field"""
self.client.force_login(self.admin)
# Test with invalid datetime string
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"expires": "invalid-datetime",
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{
"expires": [
"Datetime has wrong format. Use one of these formats instead: "
"YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
]
},
)
# Test with invalid format
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "test-sa",
"expires": "2024-13-45", # Invalid month/day
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{
"expires": [
"Datetime has wrong format. Use one of these formats instead: "
"YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
]
},
)
def test_service_account_validation_multiple_errors(self):
"""Test service account creation with multiple validation errors"""
self.client.force_login(self.admin)
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "", # Empty username
"create_group": "invalid", # Invalid boolean
"expiring": 123, # Invalid boolean
"expires": "not-a-date", # Invalid datetime
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{
"name": ["This field may not be blank."],
"create_group": ["Must be a valid boolean."],
"expiring": ["Must be a valid boolean."],
"expires": [
"Datetime has wrong format. Use one of these formats instead: "
"YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
],
},
)
def test_service_account_validation_user_friendly_duplicate_error(self):
"""Test that duplicate username returns user-friendly error, not database error"""
self.client.force_login(self.admin)
# Create first service account
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "duplicate-username-test",
"create_group": True,
},
)
self.assertEqual(response.status_code, 200)
# Attempt to create second with same username
response = self.client.post(
reverse("authentik_api:user-service-account"),
data={
"name": "duplicate-username-test",
"create_group": True,
},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content,
{"name": ["This field must be unique."]},
)

View File

@@ -30,7 +30,6 @@ from authentik.flows.views.interface import FlowInterfaceView
from authentik.root.asgi_middleware import AuthMiddlewareStack
from authentik.root.messages.consumer import MessageConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware
from authentik.tenants.channels import TenantsAwareMiddleware
urlpatterns = [
path(
@@ -98,9 +97,7 @@ api_urlpatterns = [
websocket_urlpatterns = [
path(
"ws/client/",
ChannelsLoggingMiddleware(
TenantsAwareMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi()))
),
ChannelsLoggingMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi())),
),
]

View File

@@ -7,12 +7,13 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.base import load_pem_x509_certificate
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.config import CONFIG
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -37,7 +38,7 @@ def ensure_certificate_valid(body: str):
@actor(description=_("Discover, import and update certificates from the filesystem."))
def certificate_discovery():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
certs = {}
private_keys = {}
discovered = 0

View File

@@ -1,5 +1,6 @@
from django.db.models.aggregates import Count
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog import get_logger
@@ -7,7 +8,7 @@ from authentik.enterprise.policies.unique_password.models import (
UniquePasswordPolicy,
UserPasswordHistory,
)
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -18,7 +19,7 @@ LOGGER = get_logger()
)
)
def check_and_purge_password_history():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
if not UniquePasswordPolicy.objects.exists():
UserPasswordHistory.objects.all().delete()
@@ -38,7 +39,7 @@ def trim_password_histories():
UniquePasswordPolicy policies.
"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
# No policy, we'll let the cleanup above do its thing
if not UniquePasswordPolicy.objects.exists():

View File

@@ -4,6 +4,7 @@ from uuid import UUID
from django.http import HttpRequest
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests.exceptions import RequestException
from structlog.stdlib import get_logger
@@ -19,7 +20,7 @@ from authentik.enterprise.providers.ssf.models import (
from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
session = get_http_session()
LOGGER = get_logger()
@@ -73,7 +74,7 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:
@actor(description=_("Send an SSF event."))
def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:

View File

@@ -4,6 +4,7 @@ from uuid import UUID
from django.db.models.query_utils import Q
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger
@@ -18,7 +19,7 @@ from authentik.events.models import (
from authentik.lib.utils.db import chunked_queryset
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -37,7 +38,7 @@ def event_trigger_dispatch(event_uuid: UUID):
)
def event_trigger_handler(event_uuid: UUID, trigger_name: str):
"""Check if policies attached to NotificationRule match event"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
event: Event = Event.objects.filter(event_uuid=event_uuid).first()
if not event:
@@ -130,7 +131,7 @@ def gdpr_cleanup(user_pk: int):
@actor(description=_("Cleanup seen notifications and notifications whose event expired."))
def notification_cleanup():
"""Cleanup seen notifications and notifications whose event expired."""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
amount = notifications.count()
notifications.delete()

View File

@@ -54,7 +54,6 @@ class Challenge(PassiveSerializer):
flow_info = ContextualFlowInfo(required=False)
component = CharField(default="")
xid = CharField(required=False)
response_errors = DictField(
child=ErrorDetailSerializer(many=True), allow_empty=True, required=False

View File

@@ -143,12 +143,10 @@ class FlowPlan:
request: HttpRequest,
flow: Flow,
allowed_silent_types: list["StageView"] | None = None,
**get_params,
) -> HttpResponse:
"""Redirect to the flow executor for this flow plan"""
from authentik.flows.views.executor import (
SESSION_KEY_PLAN,
FlowContainer,
FlowExecutorView,
)
@@ -159,7 +157,6 @@ class FlowPlan:
# No unskippable stages found, so we can directly return the response of the last stage
final_stage: type[StageView] = self.bindings[-1].stage.view
temp_exec = FlowExecutorView(flow=flow, request=request, plan=self)
temp_exec.container = FlowContainer(request)
temp_exec.current_stage = self.bindings[-1].stage
temp_exec.current_stage_view = final_stage
temp_exec.setup(request, flow.slug)
@@ -177,9 +174,6 @@ class FlowPlan:
):
get_qs["inspector"] = "available"
for key, value in get_params:
get_qs[key] = value
return redirect_with_qs(
"authentik_core:if-flow",
get_qs,

View File

@@ -192,7 +192,6 @@ class ChallengeStageView(StageView):
)
flow_info.is_valid()
challenge.initial_data["flow_info"] = flow_info.data
challenge.initial_data["xid"] = self.executor.container.exec_id
if isinstance(challenge, WithUserInfoChallenge):
# If there's a pending user, update the `username` field
# this field is only used by password managers.

View File

@@ -29,7 +29,7 @@ window.authentik.flow = {
{% block body %}
<ak-skip-to-content></ak-skip-to-content>
<ak-message-container></ak-message-container>
<ak-flow-executor flowSlug="{{ flow.slug }}" xid="{{ xid }}">
<ak-flow-executor flowSlug="{{ flow.slug }}">
<ak-loading></ak-loading>
</ak-flow-executor>
{% endblock %}

View File

@@ -1,7 +1,6 @@
"""authentik multi-stage authentication engine"""
from copy import deepcopy
from uuid import uuid4
from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin
@@ -64,7 +63,6 @@ from authentik.policies.engine import PolicyEngine
LOGGER = get_logger()
# Argument used to redirect user after login
NEXT_ARG_NAME = "next"
SESSION_KEY_PLAN_CONTAINER = "authentik/flows/plan_container/%s"
SESSION_KEY_PLAN = "authentik/flows/plan"
SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre"
SESSION_KEY_GET = "authentik/flows/get"
@@ -72,7 +70,6 @@ SESSION_KEY_POST = "authentik/flows/post"
SESSION_KEY_HISTORY = "authentik/flows/history"
QS_KEY_TOKEN = "flow_token" # nosec
QS_QUERY = "query"
QS_EXEC_ID = "xid"
def challenge_types():
@@ -99,88 +96,6 @@ class InvalidStageError(SentryIgnoredException):
"""Error raised when a challenge from a stage is not valid"""
class FlowContainer:
"""Allow for multiple concurrent flow executions in the same session"""
def __init__(self, request: HttpRequest, exec_id: str | None = None) -> None:
self.request = request
self.exec_id = exec_id
@staticmethod
def new(request: HttpRequest):
exec_id = str(uuid4())
request.session[SESSION_KEY_PLAN_CONTAINER % exec_id] = {}
return FlowContainer(request, exec_id)
def exists(self) -> bool:
"""Check if flow exists in container/session"""
return SESSION_KEY_PLAN in self.session
def save(self):
self.request.session.modified = True
@property
def session(self):
# Backwards compatibility: store session plan/etc directly in session
if not self.exec_id:
return self.request.session
self.request.session.setdefault(SESSION_KEY_PLAN_CONTAINER % self.exec_id, {})
return self.request.session.get(SESSION_KEY_PLAN_CONTAINER % self.exec_id, {})
@property
def plan(self) -> FlowPlan:
return self.session.get(SESSION_KEY_PLAN)
def to_redirect(
self,
request: HttpRequest,
flow: Flow,
allowed_silent_types: list[StageView] | None = None,
**get_params,
) -> HttpResponse:
get_params[QS_EXEC_ID] = self.exec_id
return self.plan.to_redirect(
request, flow, allowed_silent_types=allowed_silent_types, **get_params
)
@plan.setter
def plan(self, value: FlowPlan):
self.session[SESSION_KEY_PLAN] = value
self.request.session.modified = True
self.save()
@property
def application_pre(self):
return self.session.get(SESSION_KEY_APPLICATION_PRE)
@property
def get(self) -> QueryDict:
return self.session.get(SESSION_KEY_GET)
@get.setter
def get(self, value: QueryDict):
self.session[SESSION_KEY_GET] = value
self.save()
@property
def post(self) -> QueryDict:
return self.session.get(SESSION_KEY_POST)
@post.setter
def post(self, value: QueryDict):
self.session[SESSION_KEY_POST] = value
self.save()
@property
def history(self) -> list[FlowPlan]:
return self.session.get(SESSION_KEY_HISTORY)
@history.setter
def history(self, value: list[FlowPlan]):
self.session[SESSION_KEY_HISTORY] = value
self.save()
@method_decorator(xframe_options_sameorigin, name="dispatch")
class FlowExecutorView(APIView):
"""Flow executor, passing requests to Stage Views"""
@@ -188,9 +103,8 @@ class FlowExecutorView(APIView):
permission_classes = [AllowAny]
flow: Flow = None
plan: FlowPlan | None = None
container: FlowContainer
plan: FlowPlan | None = None
current_binding: FlowStageBinding | None = None
current_stage: Stage
current_stage_view: View
@@ -246,12 +160,10 @@ class FlowExecutorView(APIView):
if QS_KEY_TOKEN in get_params:
plan = self._check_flow_token(get_params[QS_KEY_TOKEN])
if plan:
container = FlowContainer.new(request)
container.plan = plan
self.request.session[SESSION_KEY_PLAN] = plan
# Early check if there's an active Plan for the current session
self.container = FlowContainer(request, request.GET.get(QS_EXEC_ID))
if self.container.exists():
self.plan: FlowPlan = self.container.plan
if SESSION_KEY_PLAN in self.request.session:
self.plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
if self.plan.flow_pk != self.flow.pk.hex:
self._logger.warning(
"f(exec): Found existing plan for other flow, deleting plan",
@@ -264,14 +176,13 @@ class FlowExecutorView(APIView):
self._logger.debug("f(exec): Continuing existing plan")
# Initial flow request, check if we have an upstream query string passed in
self.container.get = get_params
request.session[SESSION_KEY_GET] = get_params
# Don't check session again as we've either already loaded the plan or we need to plan
if not self.plan:
self.container.history = []
request.session[SESSION_KEY_HISTORY] = []
self._logger.debug("f(exec): No active Plan found, initiating planner")
try:
self.plan = self._initiate_plan()
self.container.plan = self.plan
except FlowNonApplicableException as exc:
self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
return self.handle_invalid_flow(exc)
@@ -344,19 +255,12 @@ class FlowExecutorView(APIView):
request=OpenApiTypes.NONE,
parameters=[
OpenApiParameter(
name=QS_QUERY,
name="query",
location=OpenApiParameter.QUERY,
required=True,
description="Querystring as received",
type=OpenApiTypes.STR,
),
OpenApiParameter(
name=QS_EXEC_ID,
location=OpenApiParameter.QUERY,
required=False,
description="Flow execution ID",
type=OpenApiTypes.STR,
),
)
],
operation_id="flows_executor_get",
)
@@ -383,8 +287,8 @@ class FlowExecutorView(APIView):
span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response, self.container.exec_id)
except Exception as exc:
return to_stage_response(request, stage_response)
except Exception as exc: # noqa
return self.handle_exception(exc)
@extend_schema(
@@ -402,19 +306,12 @@ class FlowExecutorView(APIView):
),
parameters=[
OpenApiParameter(
name=QS_QUERY,
name="query",
location=OpenApiParameter.QUERY,
required=True,
description="Querystring as received",
type=OpenApiTypes.STR,
),
OpenApiParameter(
name=QS_EXEC_ID,
location=OpenApiParameter.QUERY,
required=True,
description="Flow execution ID",
type=OpenApiTypes.STR,
),
)
],
operation_id="flows_executor_solve",
)
@@ -441,15 +338,14 @@ class FlowExecutorView(APIView):
span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response, self.container.exec_id)
return to_stage_response(request, stage_response)
except Exception as exc: # noqa
return self.handle_exception(exc)
def _initiate_plan(self) -> FlowPlan:
planner = FlowPlanner(self.flow)
plan = planner.plan(self.request)
container = FlowContainer.new(self.request)
container.plan = plan
self.request.session[SESSION_KEY_PLAN] = plan
try:
# Call the has_stages getter to check that
# there are no issues with the class we might've gotten
@@ -473,7 +369,7 @@ class FlowExecutorView(APIView):
except FlowNonApplicableException as exc:
self._logger.warning("f(exec): Flow restart not applicable to current user", exc=exc)
return self.handle_invalid_flow(exc)
self.container.plan = plan
self.request.session[SESSION_KEY_PLAN] = plan
kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug})
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
@@ -495,13 +391,9 @@ class FlowExecutorView(APIView):
)
self.cancel()
if next_param and not is_url_absolute(next_param):
return to_stage_response(
self.request, redirect_with_qs(next_param), self.container.exec_id
)
return to_stage_response(self.request, redirect_with_qs(next_param))
return to_stage_response(
self.request,
self.stage_invalid(error_message=_("Invalid next URL")),
self.container.exec_id,
self.request, self.stage_invalid(error_message=_("Invalid next URL"))
)
def stage_ok(self) -> HttpResponse:
@@ -515,7 +407,7 @@ class FlowExecutorView(APIView):
self.current_stage_view.cleanup()
self.request.session.get(SESSION_KEY_HISTORY, []).append(deepcopy(self.plan))
self.plan.pop()
self.container.plan = self.plan
self.request.session[SESSION_KEY_PLAN] = self.plan
if self.plan.bindings:
self._logger.debug(
"f(exec): Continuing with next stage",
@@ -558,7 +450,6 @@ class FlowExecutorView(APIView):
def cancel(self):
"""Cancel current flow execution"""
# TODO: Clean up container
keys_to_delete = [
SESSION_KEY_APPLICATION_PRE,
SESSION_KEY_PLAN,
@@ -581,8 +472,8 @@ class CancelView(View):
def get(self, request: HttpRequest) -> HttpResponse:
"""View which canels the currently active plan"""
if FlowContainer(request, request.GET.get(QS_EXEC_ID)).exists():
del request.session[SESSION_KEY_PLAN_CONTAINER % request.GET.get(QS_EXEC_ID)]
if SESSION_KEY_PLAN in request.session:
del request.session[SESSION_KEY_PLAN]
LOGGER.debug("Canceled current plan")
return redirect("authentik_flows:default-invalidation")
@@ -630,12 +521,19 @@ class ToDefaultFlow(View):
def dispatch(self, request: HttpRequest) -> HttpResponse:
flow = self.get_flow()
get_qs = request.GET.copy()
get_qs[QS_EXEC_ID] = str(uuid4())
return redirect_with_qs("authentik_core:if-flow", get_qs, flow_slug=flow.slug)
# If user already has a pending plan, clear it so we don't have to later.
if SESSION_KEY_PLAN in self.request.session:
plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
if plan.flow_pk != flow.pk.hex:
LOGGER.warning(
"f(def): Found existing plan for other flow, deleting plan",
flow_slug=flow.slug,
)
del self.request.session[SESSION_KEY_PLAN]
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> HttpResponse:
def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:
"""Convert normal HttpResponse into JSON Response"""
if (
isinstance(source, HttpResponseRedirect)
@@ -654,7 +552,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> H
RedirectChallenge(
{
"to": str(redirect_url),
"xid": xid,
}
)
)
@@ -663,7 +560,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> H
ShellChallenge(
{
"body": source.render().content.decode("utf-8"),
"xid": xid,
}
)
)
@@ -673,7 +569,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> H
ShellChallenge(
{
"body": source.content.decode("utf-8"),
"xid": xid,
}
)
)
@@ -705,6 +600,4 @@ class ConfigureFlowInitView(LoginRequiredMixin, View):
except FlowNonApplicableException:
LOGGER.warning("Flow not applicable to user")
raise Http404 from None
container = FlowContainer.new(request)
container.plan = plan
return container.to_redirect(request, stage.configure_flow)
return plan.to_redirect(request, stage.configure_flow)

View File

@@ -7,7 +7,6 @@ from ua_parser.user_agent_parser import Parse
from authentik.core.views.interface import InterfaceView
from authentik.flows.models import Flow
from authentik.flows.views.executor import QS_EXEC_ID
class FlowInterfaceView(InterfaceView):
@@ -18,7 +17,6 @@ class FlowInterfaceView(InterfaceView):
kwargs["flow"] = flow
kwargs["flow_background_url"] = flow.background_url(self.request)
kwargs["inspector"] = "inspector" in self.request.GET
kwargs["xid"] = self.request.GET.get(QS_EXEC_ID)
return super().get_context_data(**kwargs)
def compat_needs_sfe(self) -> bool:

View File

@@ -3,10 +3,9 @@
import re
import socket
from ipaddress import ip_address, ip_network
from smtplib import SMTPException
from textwrap import indent
from types import CodeType
from typing import TYPE_CHECKING, Any
from typing import Any
from cachetools import TLRUCache, cached
from django.core.exceptions import FieldError
@@ -30,10 +29,6 @@ from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider
from authentik.stages.authenticator import devices_for_user
from authentik.stages.email.utils import TemplateEmailMessage
if TYPE_CHECKING:
from authentik.stages.email.models import EmailStage
LOGGER = get_logger()
@@ -62,12 +57,11 @@ class BaseEvaluator:
self._globals = {
"ak_call_policy": self.expr_func_call_policy,
"ak_create_event": self.expr_event_create,
"ak_create_jwt": self.expr_create_jwt,
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
"ak_logger": get_logger(self._filename).bind(),
"ak_send_email": self.expr_send_email,
"ak_user_by": BaseEvaluator.expr_user_by,
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
"ak_create_jwt": self.expr_create_jwt,
"ip_address": ip_address,
"ip_network": ip_network,
"list_flatten": BaseEvaluator.expr_flatten,
@@ -222,81 +216,6 @@ class BaseEvaluator:
access_token.save()
return access_token.token
def expr_send_email(
self,
address: str | list[str],
subject: str,
body: str | None = None,
stage: "EmailStage | None" = None,
template: str | None = None,
context: dict | None = None,
) -> bool:
"""Send an email using authentik's email system
Args:
address: Email address(es) to send to. Can be:
- Single email: "user@example.com"
- List of emails: ["user1@example.com", "user2@example.com"]
subject: Email subject
body: Email body (plain text/HTML). Mutually exclusive with template.
stage: EmailStage instance to use for settings. If None, uses global settings.
template: Template name to render. Mutually exclusive with body.
context: Additional context variables for template rendering.
Returns:
bool: True if email was queued successfully, False otherwise
"""
# Deferred imports to avoid circular import issues
from authentik.stages.email.tasks import send_mails
if body and template:
raise ValueError("body and template parameters are mutually exclusive")
if not body and not template:
raise ValueError("Either body or template parameter must be provided")
# Normalize address parameter to list of (name, email) tuples
if isinstance(address, str):
# Single email address
to_addresses = [("", address)]
elif isinstance(address, list):
if not address:
raise ValueError("Address list cannot be empty")
# List of email strings
to_addresses = [("", email) for email in address]
else:
raise ValueError("Address must be a string or list of strings")
try:
if template is not None:
# Use all available context from the evaluator for template rendering
template_context = self._context.copy()
# Add any custom context passed to the function
if context:
template_context.update(context)
# Use template rendering
message = TemplateEmailMessage(
subject=subject,
to=to_addresses,
template_name=template,
template_context=template_context,
)
else:
# Use plain body
message = TemplateEmailMessage(
subject=subject,
to=to_addresses,
body=body,
)
send_mails(stage, message)
return True
except (SMTPException, ConnectionError, ValidationError, ValueError) as exc:
LOGGER.warning("Failed to send email", exc=exc, addresses=to_addresses, subject=subject)
return False
def wrap_expression(self, expression: str) -> str:
"""Wrap expression in a function, call it, and save the result as `result`"""
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())

View File

@@ -112,7 +112,6 @@ def get_logger_config():
"hpack": "WARNING",
"httpx": "WARNING",
"azure": "WARNING",
"channels_postgres": "WARNING",
}
for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = {

View File

@@ -3,15 +3,19 @@
from asyncio.exceptions import CancelledError
from typing import Any
from channels_redis.core import ChannelFull
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
from django.db import DatabaseError, InternalError, OperationalError, ProgrammingError
from django.http.response import Http404
from django_redis.exceptions import ConnectionInterrupted
from docker.errors import DockerException
from dramatiq.errors import Retry
from h11 import LocalProtocolError
from ldap3.core.exceptions import LDAPException
from psycopg.errors import Error
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException
from sentry_sdk import HttpTransport, get_current_scope
from sentry_sdk import init as sentry_sdk_init
@@ -19,6 +23,7 @@ from sentry_sdk.api import set_tag
from sentry_sdk.integrations.argv import ArgvIntegration
from sentry_sdk.integrations.django import DjangoIntegration
from sentry_sdk.integrations.dramatiq import DramatiqIntegration
from sentry_sdk.integrations.redis import RedisIntegration
from sentry_sdk.integrations.socket import SocketIntegration
from sentry_sdk.integrations.stdlib import StdlibIntegration
from sentry_sdk.integrations.threading import ThreadingIntegration
@@ -54,7 +59,13 @@ ignored_classes = (
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
@@ -101,6 +112,7 @@ def sentry_init(**sentry_init_kwargs):
ArgvIntegration(),
DjangoIntegration(transaction_style="function_name", cache_spans=True),
DramatiqIntegration(),
RedisIntegration(),
SocketIntegration(),
StdlibIntegration(),
ThreadingIntegration(propagate_hub=True),
@@ -147,7 +159,9 @@ def before_send(event: dict, hint: dict) -> dict | None:
if event["logger"] in [
"asyncio",
"multiprocessing",
"django_redis",
"django.security.DisallowedHost",
"django_redis.cache",
"paramiko.transport",
]:
return None

View File

@@ -2,6 +2,7 @@ from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import Actor
from dramatiq.composition import group
from dramatiq.errors import Retry
@@ -21,7 +22,6 @@ from authentik.lib.sync.outgoing.exceptions import (
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.errors import exception_to_dict
from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
@@ -61,7 +61,7 @@ class SyncTasks:
provider_pk: int,
sync_objects: Actor[[str, int, int, bool], None],
):
task = CurrentTask.get_task()
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
@@ -118,7 +118,7 @@ class SyncTasks:
override_dry_run=False,
**filter,
):
task = CurrentTask.get_task()
task: Task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
@@ -173,7 +173,7 @@ class SyncTasks:
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
task.warning(
f"Failed to sync {str(obj)} due to transient error: {str(exc)}",
f"Failed to sync {str(obj)} due to " f"transient error: {str(exc)}",
obj=sanitize_item(obj),
exception=exception_to_dict(exc),
)
@@ -207,7 +207,7 @@ class SyncTasks:
provider_pk: int,
raw_op: str,
):
task = CurrentTask.get_task()
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
@@ -281,7 +281,7 @@ class SyncTasks:
action: str,
pk_set: list[int],
):
task = CurrentTask.get_task()
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)

View File

@@ -1,7 +1,5 @@
"""Test Evaluator base functions"""
from unittest.mock import patch
from django.test import RequestFactory, TestCase
from django.urls import reverse
from jwt import decode
@@ -79,163 +77,3 @@ class TestEvaluator(TestCase):
jwt, provider.client_secret, algorithms=["HS256"], audience=provider.client_id
)
self.assertEqual(decoded["preferred_username"], user.username)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_body(self, mock_send_mails):
"""Test ak_send_email with body parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
evaluator._context = {"user": user}
# Test sending email with body
result = evaluator.evaluate(
"return ak_send_email('test@example.com', 'Test Subject', body='Test Body')"
)
self.assertTrue(result)
mock_send_mails.assert_called_once()
# Verify the call arguments - send_mails is called with (stage, message)
args, kwargs = mock_send_mails.call_args
stage, message = args
# Check that global settings are used (stage is None)
self.assertIsNone(stage)
# Check message properties
self.assertEqual(message.subject, "Test Subject")
self.assertEqual(message.to, ["test@example.com"])
self.assertEqual(message.body, "Test Body")
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_template(self, mock_send_mails):
"""Test ak_send_email with template parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
evaluator._context = {"user": user}
# Test sending email with template
result = evaluator.evaluate(
"return ak_send_email('test@example.com', 'Test Subject', "
"template='email/password_reset.html')"
)
self.assertTrue(result)
mock_send_mails.assert_called_once()
def test_expr_send_email_validation_errors(self):
"""Test ak_send_email validation errors"""
evaluator = BaseEvaluator(generate_id())
# Test error when both body and template are provided
with self.assertRaises(ValueError) as cm:
evaluator.evaluate(
"return ak_send_email('test@example.com', 'Test', "
"body='Body', template='template.html')"
)
self.assertIn("mutually exclusive", str(cm.exception))
# Test error when neither body nor template are provided
with self.assertRaises(ValueError) as cm:
evaluator.evaluate("return ak_send_email('test@example.com', 'Test')")
self.assertIn("Either body or template parameter must be provided", str(cm.exception))
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_custom_stage(self, mock_send_mails):
"""Test ak_send_email with custom EmailStage"""
from authentik.stages.email.models import EmailStage
user = create_test_user()
custom_stage = EmailStage(
name="custom-stage", use_global_settings=False, from_address="custom@example.com"
)
evaluator = BaseEvaluator(generate_id())
evaluator._context = {"user": user, "custom_stage": custom_stage}
# Test sending email with custom stage
result = evaluator.evaluate(
"return ak_send_email('test@example.com', 'Test Subject', "
"body='Test Body', stage=custom_stage)"
)
self.assertTrue(result)
mock_send_mails.assert_called_once()
# Verify the custom stage was used
args, kwargs = mock_send_mails.call_args
stage, message = args
self.assertEqual(stage, custom_stage)
self.assertFalse(stage.use_global_settings)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_context(self, mock_send_mails):
"""Test ak_send_email with custom context parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
evaluator._context = {"user": user, "request_id": "123"}
# Test sending email with template and custom context
result = evaluator.evaluate(
"return ak_send_email('test@example.com', 'Test Subject', "
"template='email/password_reset.html', "
"context={'url': 'http://localhost', 'expires': '2026-01-01'})"
)
self.assertTrue(result)
mock_send_mails.assert_called_once()
# Verify the call arguments - send_mails is called with (stage, message)
args, kwargs = mock_send_mails.call_args
stage, message = args
# Check that global settings are used (stage is None)
self.assertIsNone(stage)
self.assertEqual(message.subject, "Test Subject")
self.assertEqual(message.to, ["test@example.com"])
self.assertIn("2026-01-01", message.body)
self.assertIn("http://localhost", message.body)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_multiple_addresses(self, mock_send_mails):
"""Test ak_send_email with multiple email addresses"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
evaluator._context = {"user": user}
# Test sending email to multiple addresses
result = evaluator.evaluate(
"return ak_send_email(['user1@example.com', 'user2@example.com'], "
"'Test Subject', body='Test Body')"
)
self.assertTrue(result)
mock_send_mails.assert_called_once()
# Verify the call arguments - send_mails is called with (stage, message)
args, kwargs = mock_send_mails.call_args
stage, message = args
# Check that global settings are used (stage is None)
self.assertIsNone(stage)
# Check message properties - should have multiple recipients
self.assertEqual(message.subject, "Test Subject")
self.assertEqual(message.to, ["user1@example.com", "user2@example.com"])
self.assertEqual(message.body, "Test Body")
def test_expr_send_email_multiple_addresses_validation(self):
"""Test ak_send_email validation with multiple addresses"""
evaluator = BaseEvaluator(generate_id())
# Test error when empty list is provided
with self.assertRaises(ValueError) as cm:
evaluator.evaluate("return ak_send_email([], 'Test', body='Body')")
self.assertIn("Address list cannot be empty", str(cm.exception))
# Test error when invalid type is provided
with self.assertRaises(ValueError) as cm:
evaluator.evaluate("return ak_send_email(123, 'Test', body='Body')")
self.assertIn("Address must be a string or list of strings", str(cm.exception))

View File

@@ -3,9 +3,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from hashlib import sha256
from typing import Any
from uuid import UUID
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
@@ -20,15 +18,8 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
def build_outpost_group(outpost_pk: str | UUID) -> str:
return sha256(f"{connection.schema_name}/group_outpost_{str(outpost_pk)}".encode()).hexdigest()
def build_outpost_group_instance(outpost_pk: str | UUID, instance: str) -> str:
return sha256(
f"{connection.schema_name}/group_outpost_{str(outpost_pk)}_{instance}".encode()
).hexdigest()
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
class WebsocketMessageInstruction(IntEnum):
@@ -73,24 +64,26 @@ class OutpostConsumer(JsonWebsocketConsumer):
def connect(self):
uuid = self.scope["url_route"]["kwargs"]["pk"]
user = self.scope["user"]
self.outpost: Outpost | None = (
outpost = (
get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
)
if self.outpost is None:
if not outpost:
raise DenyConnection()
self.logger = self.logger.bind(outpost=self.outpost)
self.logger = self.logger.bind(outpost=outpost)
try:
self.accept()
except RuntimeError as exc:
self.logger.warning("runtime error during accept", exc=exc)
raise DenyConnection() from None
self.outpost = outpost
query = QueryDict(self.scope["query_string"].decode())
self.instance_uid = query.get("instance_uuid", self.channel_name)
async_to_sync(self.channel_layer.group_add)(
build_outpost_group(self.outpost.pk), self.channel_name
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
async_to_sync(self.channel_layer.group_add)(
build_outpost_group_instance(self.outpost.pk, self.instance_uid),
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
GAUGE_OUTPOSTS_CONNECTED.labels(
@@ -103,11 +96,12 @@ class OutpostConsumer(JsonWebsocketConsumer):
def disconnect(self, code):
if self.outpost:
async_to_sync(self.channel_layer.group_discard)(
build_outpost_group(self.outpost.pk), self.channel_name
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
if self.instance_uid:
async_to_sync(self.channel_layer.group_discard)(
build_outpost_group_instance(self.outpost.pk, self.instance_uid),
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
if self.outpost and self.instance_uid:

View File

@@ -12,6 +12,7 @@ from channels.layers import get_channel_layer
from django.core.cache import cache
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from docker.constants import DEFAULT_UNIX_SOCKET
from dramatiq.actor import actor
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
@@ -20,7 +21,7 @@ from structlog.stdlib import get_logger
from yaml import safe_load
from authentik.lib.config import CONFIG
from authentik.outposts.consumer import build_outpost_group
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient
from authentik.outposts.controllers.kubernetes import KubernetesClient
@@ -40,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
@@ -107,7 +108,7 @@ def outpost_service_connection_monitor(connection_pk: Any):
@actor(description=_("Create/update/monitor/delete the deployment of an Outpost."))
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
"""Create/update/monitor/delete the deployment of an Outpost"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
self.set_uid(outpost_pk)
logs = []
if from_cache:
@@ -141,7 +142,7 @@ def outpost_token_ensurer():
"""
Periodically ensure that all Outposts have valid Service Accounts and Tokens
"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
all_outposts = Outpost.objects.all()
for outpost in all_outposts:
_ = outpost.token
@@ -160,7 +161,7 @@ def outpost_send_update(pk: Any):
_ = outpost.token
outpost.build_user_permissions(outpost.user)
layer = get_channel_layer()
group = build_outpost_group(outpost.pk)
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
LOGGER.debug("sending update", channel=group, outpost=outpost)
async_to_sync(layer.group_send)(group, {"type": "event.update"})
@@ -168,7 +169,7 @@ def outpost_send_update(pk: Any):
@actor(description=_("Checks the local environment and create Service connections."))
def outpost_connection_discovery():
"""Checks the local environment and create Service connections."""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
if not CONFIG.get_bool("outposts.discover"):
self.info("Outpost integration discovery is disabled")
return
@@ -212,7 +213,7 @@ def outpost_session_end(session_id: str):
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.all():
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
group = build_outpost_group(outpost.pk)
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{

View File

@@ -45,7 +45,7 @@ class TestOutpostWS(TransactionTestCase):
communicator = WebsocketCommunicator(
URLRouter(websocket.websocket_urlpatterns),
f"/ws/outpost/{self.outpost.pk}/",
[(b"authorization", f"Bearer {self.token}".encode())],
{b"authorization": f"Bearer {self.token}".encode()},
)
connected, _ = await communicator.connect()
self.assertTrue(connected)
@@ -56,7 +56,7 @@ class TestOutpostWS(TransactionTestCase):
communicator = WebsocketCommunicator(
URLRouter(websocket.websocket_urlpatterns),
f"/ws/outpost/{self.outpost.pk}/",
[(b"authorization", f"Bearer {self.token}".encode())],
{b"authorization": f"Bearer {self.token}".encode()},
)
connected, _ = await communicator.connect()
self.assertTrue(connected)
@@ -83,7 +83,7 @@ class TestOutpostWS(TransactionTestCase):
communicator = WebsocketCommunicator(
URLRouter(websocket.websocket_urlpatterns),
f"/ws/outpost/{self.outpost.pk}/",
[(b"authorization", f"Bearer {self.token}".encode())],
{b"authorization": f"Bearer {self.token}".encode()},
)
connected, _ = await communicator.connect()
self.assertTrue(connected)

View File

@@ -11,14 +11,11 @@ from authentik.outposts.api.service_connections import (
from authentik.outposts.channels import TokenOutpostMiddleware
from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware
from authentik.tenants.channels import TenantsAwareMiddleware
websocket_urlpatterns = [
path(
"ws/outpost/<uuid:pk>/",
ChannelsLoggingMiddleware(
TenantsAwareMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi()))
),
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
),
]

View File

@@ -23,8 +23,6 @@ SCOPE_OPENID_PROFILE = "profile"
SCOPE_OPENID_EMAIL = "email"
SCOPE_OFFLINE_ACCESS = "offline_access"
UI_LOCALES = "ui_locales"
# https://www.iana.org/assignments/oauth-parameters/auth-parameters.xhtml#pkce-code-challenge-method
PKCE_METHOD_PLAIN = "plain"
PKCE_METHOD_S256 = "S256"

View File

@@ -1,13 +1,14 @@
"""OAuth2 Provider Tasks"""
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger
from authentik.lib.utils.http import get_http_session
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.oauth2.utils import create_logout_token
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -30,7 +31,7 @@ def send_backchannel_logout_request(
Returns:
bool: True if the request was sent successfully, False otherwise
"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
LOGGER.debug("Sending back-channel logout request", provider_pk=provider_pk, sub=sub)
provider = OAuth2Provider.objects.filter(pk=provider_pk).first()

View File

@@ -1,7 +1,6 @@
"""Test authorize view"""
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, urlparse
from django.test import RequestFactory
from django.urls import reverse
@@ -671,55 +670,3 @@ class TestAuthorize(OAuthTestCase):
)
parsed = OAuthAuthorizationParams.from_request(request)
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
def test_ui_locales(self):
"""Test OIDC ui_locales authorization"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
self.client.logout()
response = self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"ui_locales": "invalid fr",
},
)
parsed = parse_qs(urlparse(response.url).query)
self.assertEqual(parsed["locale"], ["fr"])
def test_ui_locales_invalid(self):
"""Test OIDC ui_locales authorization"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
self.client.logout()
response = self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"ui_locales": "invalid",
},
)
parsed = parse_qs(urlparse(response.url).query)
self.assertNotIn("locale", parsed)

View File

@@ -5,14 +5,13 @@ from datetime import timedelta
from json import dumps
from re import error as RegexError
from re import fullmatch
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlsplit, urlunparse, urlunsplit
from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
from uuid import uuid4
from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.http import HttpRequest, HttpResponse
from django.http.response import Http404, HttpResponseBadRequest
from django.shortcuts import get_object_or_404
from django.utils import timezone, translation
from django.utils import timezone
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
@@ -42,7 +41,6 @@ from authentik.providers.oauth2.constants import (
SCOPE_OFFLINE_ACCESS,
SCOPE_OPENID,
TOKEN_TYPE,
UI_LOCALES,
)
from authentik.providers.oauth2.errors import (
AuthorizeError,
@@ -389,45 +387,6 @@ class AuthorizationFlowInitView(BufferedPolicyAccessView):
request.context["oauth_response_type"] = self.params.response_type
return request
def dispatch_with_language(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Activate language from OIDC specific ui_locales parameter, picking the earliest one
available"""
selected_language = None
if UI_LOCALES in self.request.GET:
languages = str(self.request.GET[UI_LOCALES]).split(" ")
for language in languages:
if translation.check_for_language(language):
selected_language = translation.get_supported_language_variant(language)
LOGGER.debug(
"Activating language from oidc ui_locales", locale=selected_language
)
break
translation.activate(selected_language)
response = super().dispatch(request, *args, **kwargs)
if selected_language:
response.set_cookie(
settings.LANGUAGE_COOKIE_NAME,
selected_language,
max_age=settings.LANGUAGE_COOKIE_AGE,
path=settings.LANGUAGE_COOKIE_PATH,
domain=settings.LANGUAGE_COOKIE_DOMAIN,
secure=settings.LANGUAGE_COOKIE_SECURE,
httponly=settings.LANGUAGE_COOKIE_HTTPONLY,
samesite=settings.LANGUAGE_COOKIE_SAMESITE,
)
if isinstance(response, HttpResponseRedirect):
parsed_url = urlparse(response.url)
args = parse_qs(parsed_url.query)
args["locale"] = selected_language
response["Location"] = urlunparse(
parsed_url._replace(query=urlencode(args, quote_via=quote, doseq=True))
)
return response
def dispatch(self, request: HttpRequest, *args, **kwargs):
# Activate language before parsing params (error messages should be localised)
return self.dispatch_with_language(request, *args, **kwargs)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Start FlowPLanner, return to flow executor shell"""
# Require a login event to be set, otherwise make the user re-login

View File

@@ -5,7 +5,7 @@ from channels.layers import get_channel_layer
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from authentik.outposts.consumer import build_outpost_group
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
@@ -15,7 +15,7 @@ def proxy_on_logout(session_id: str):
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = build_outpost_group(outpost.pk)
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{

View File

@@ -1,46 +1,28 @@
"""RAC Client consumer"""
from hashlib import sha256
from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull, DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from django.db import connection
from django.http.request import QueryDict
from structlog.stdlib import BoundLogger, get_logger
from authentik.outposts.consumer import build_outpost_group_instance
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.providers.rac.models import ConnectionToken, RACProvider
def build_rac_client_group() -> str:
"""
Global broadcast group, which messages are sent to when the outpost connects back
to authentik for a specific connection
The `RACClientConsumer` consumer adds itself to this group on connection,
and removes itself once it has been assigned a specific outpost channel
"""
return sha256(f"{connection.schema_name}/group_rac_client".encode()).hexdigest()
def build_rac_client_group_session(session_key: str) -> str:
"""
A group for all connections in a given authentik session ID
A disconnect message is sent to this group when the session expires/is deleted
"""
return sha256(f"{connection.schema_name}/group_rac_client_{session_key}".encode()).hexdigest()
def build_rac_client_group_token(token: str) -> str:
"""
A group for all connections with a specific token, which in almost all cases
is just one connection, however this is used to disconnect the connection
when the token is deleted
"""
return sha256(f"{connection.schema_name}/group_rac_token_{token}".encode()).hexdigest()
# Global broadcast group, which messages are sent to when the outpost connects back
# to authentik for a specific connection
# The `RACClientConsumer` consumer adds itself to this group on connection,
# and removes itself once it has been assigned a specific outpost channel
RAC_CLIENT_GROUP = "group_rac_client"
# A group for all connections in a given authentik session ID
# A disconnect message is sent to this group when the session expires/is deleted
RAC_CLIENT_GROUP_SESSION = "group_rac_client_%(session)s"
# A group for all connections with a specific token, which in almost all cases
# is just one connection, however this is used to disconnect the connection
# when the token is deleted
RAC_CLIENT_GROUP_TOKEN = "group_rac_token_%(token)s" # nosec
# Step 1: Client connects to this websocket endpoint
# Step 2: We prepare all the connection args for Guac
@@ -63,23 +45,22 @@ class RACClientConsumer(AsyncWebsocketConsumer):
async def connect(self):
self.logger = get_logger()
await self.accept("guacamole")
await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
await self.channel_layer.group_add(
build_rac_client_group_session(self.scope["session"].session_key),
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
self.channel_name,
)
await self.init_outpost_connection()
async def disconnect(self, code):
self.logger.debug("Disconnecting")
if self.dest_channel_id:
# Tell the outpost we're disconnecting
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.disconnect",
},
)
# Tell the outpost we're disconnecting
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.disconnect",
},
)
@database_sync_to_async
def init_outpost_connection(self):
@@ -128,8 +109,10 @@ class RACClientConsumer(AsyncWebsocketConsumer):
if len(states) < 1:
continue
self.logger.debug("Sending out connection broadcast")
group = build_outpost_group_instance(outpost.pk, states[0].uid)
async_to_sync(self.channel_layer.group_send)(group, msg)
async_to_sync(self.channel_layer.group_send)(
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
msg,
)
if self.provider and self.provider.delete_token_on_disconnect:
self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
self.token.delete()
@@ -174,7 +157,7 @@ class RACClientConsumer(AsyncWebsocketConsumer):
self.dest_channel_id = outpost_channel
# Since we have a specific outpost channel now, we can remove
# ourselves from the global broadcast group
await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
async def event_send(self, event: dict):
"""Handler called by outpost websocket that sends data to this specific

View File

@@ -3,7 +3,7 @@
from channels.exceptions import ChannelFull
from channels.generic.websocket import AsyncWebsocketConsumer
from authentik.providers.rac.consumer_client import build_rac_client_group
from authentik.providers.rac.consumer_client import RAC_CLIENT_GROUP
class RACOutpostConsumer(AsyncWebsocketConsumer):
@@ -15,7 +15,7 @@ class RACOutpostConsumer(AsyncWebsocketConsumer):
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
await self.accept()
await self.channel_layer.group_send(
build_rac_client_group(),
RAC_CLIENT_GROUP,
{
"type": "event.outpost.connected",
"outpost_channel": self.channel_name,

View File

@@ -9,8 +9,8 @@ from django.dispatch import receiver
from authentik.core.models import AuthenticatedSession
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import (
build_rac_client_group_session,
build_rac_client_group_token,
RAC_CLIENT_GROUP_SESSION,
RAC_CLIENT_GROUP_TOKEN,
)
from authentik.providers.rac.models import ConnectionToken, Endpoint
@@ -19,7 +19,10 @@ from authentik.providers.rac.models import ConnectionToken, Endpoint
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer()
async_to_sync(layer.group_send)(
build_rac_client_group_session(instance.session.session_key),
RAC_CLIENT_GROUP_SESSION
% {
"session": instance.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@@ -29,7 +32,10 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **
"""Disconnect session when connection token is deleted"""
layer = get_channel_layer()
async_to_sync(layer.group_send)(
build_rac_client_group_token(instance.token),
RAC_CLIENT_GROUP_TOKEN
% {
"token": instance.token,
},
{"type": "event.disconnect", "reason": "token_delete"},
)

View File

@@ -12,7 +12,6 @@ from authentik.providers.rac.consumer_outpost import RACOutpostConsumer
from authentik.providers.rac.views import RACInterface, RACStartView
from authentik.root.asgi_middleware import AuthMiddlewareStack
from authentik.root.middleware import ChannelsLoggingMiddleware
from authentik.tenants.channels import TenantsAwareMiddleware
urlpatterns = [
path(
@@ -30,15 +29,11 @@ urlpatterns = [
websocket_urlpatterns = [
path(
"ws/rac/<str:token>/",
ChannelsLoggingMiddleware(
TenantsAwareMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi()))
),
ChannelsLoggingMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi())),
),
path(
"ws/outpost_rac/<str:channel>/",
ChannelsLoggingMiddleware(
TenantsAwareMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi()))
),
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
),
]

View File

@@ -1,35 +0,0 @@
from typing import Any
from channels_postgres.core import PostgresChannelLayer as BasePostgresChannelLayer
from channels_postgres.db import DatabaseLayer as BaseDatabaseLayer
from django.conf import settings
from psycopg_pool import AsyncConnectionPool
from authentik.root.db.base import DatabaseWrapper
class DatabaseLayer(BaseDatabaseLayer):
async def get_db_pool(self, db_params: dict[str, Any]) -> AsyncConnectionPool:
db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
db_params = db_wrapper.get_connection_params()
db_params.pop("cursor_factory")
db_params.pop("context")
return await super().get_db_pool(db_params)
class PostgresChannelLayer(BasePostgresChannelLayer):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.django_db = DatabaseLayer(self.django_db.psycopg_options, self.db_params)
@property
def db_params(self):
db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
db_params = db_wrapper.get_connection_params()
db_params.pop("cursor_factory")
db_params.pop("context")
return db_params
@db_params.setter
def db_params(self, value):
pass

View File

@@ -292,7 +292,7 @@ class ChannelsLoggingMiddleware:
except DenyConnection:
return await send({"type": "websocket.close"})
except Exception as exc:
if settings.DEBUG or settings.TEST:
if settings.DEBUG:
raise exc
LOGGER.warning("Exception in ASGI application", exc=exc)
return await send({"type": "websocket.close"})

View File

@@ -11,6 +11,8 @@ from django.dispatch import Signal
from django.http import HttpRequest, HttpResponse
from django.views import View
from django_prometheus.exports import ExportToDjangoView
from django_redis import get_redis_connection
from redis.exceptions import RedisError
monitoring_set = Signal()
@@ -42,17 +44,19 @@ class LiveView(View):
class ReadyView(View):
"""View for readiness probe, always returns Http 200, unless sql is down"""
def check_db(self):
for db_conn in connections.all():
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
"""View for readiness probe, always returns Http 200, unless sql or redis is down"""
def dispatch(self, request: HttpRequest) -> HttpResponse:
try:
self.check_db()
for db_conn in connections.all():
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
except OperationalError: # pragma: no cover
return HttpResponse(status=503)
try:
redis_conn = get_redis_connection()
redis_conn.ping()
except RedisError: # pragma: no cover
return HttpResponse(status=503)
return HttpResponse(status=200)

View File

@@ -10,7 +10,7 @@ from sentry_sdk import set_tag
from xmlsec import enable_debug_trace
from authentik import authentik_version
from authentik.lib.config import CONFIG, django_db_config
from authentik.lib.config import CONFIG, django_db_config, redis_url
from authentik.lib.logging import get_logger_config, structlog_configure
from authentik.lib.sentry import sentry_init
from authentik.lib.utils.reflection import get_env
@@ -64,7 +64,6 @@ SHARED_APPS = [
"pgactivity",
"pglock",
"channels",
"channels_postgres",
"django_dramatiq_postgres",
"authentik.tasks",
]
@@ -73,7 +72,6 @@ TENANT_APPS = [
"django.contrib.contenttypes",
"django.contrib.sessions",
"pgtrigger",
"django_postgres_cache",
"authentik.admin",
"authentik.api",
"authentik.core",
@@ -105,7 +103,6 @@ TENANT_APPS = [
"authentik.sources.plex",
"authentik.sources.saml",
"authentik.sources.scim",
"authentik.sources.telegram",
"authentik.stages.authenticator",
"authentik.stages.authenticator_duo",
"authentik.stages.authenticator_email",
@@ -228,11 +225,20 @@ REST_FRAMEWORK = {
CACHES = {
"default": {
"BACKEND": "django_postgres_cache.backend.DatabaseCache",
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db")),
"TIMEOUT": CONFIG.get_int("cache.timeout", 300),
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
},
"KEY_PREFIX": "authentik_cache",
"KEY_FUNCTION": "django_tenants.cache.make_key",
"REVERSE_KEY_FUNCTION": "django_tenants.cache.reverse_key",
}
}
DJANGO_REDIS_SCAN_ITERSIZE = 1000
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
SESSION_ENGINE = "authentik.core.sessions"
# Configured via custom SessionMiddleware
# SESSION_COOKIE_SAMESITE = "None"
@@ -290,6 +296,16 @@ TEMPLATES = [
ASGI_APPLICATION = "authentik.root.asgi.application"
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
"CONFIG": {
"hosts": [CONFIG.get("channel.url") or redis_url(CONFIG.get("redis.db"))],
"prefix": "authentik_channels_",
},
},
}
# Database
# https://docs.djangoproject.com/en/2.1/ref/settings/#databases
@@ -302,16 +318,6 @@ DATABASE_ROUTERS = (
"django_tenants.routers.TenantSyncRouter",
)
CHANNEL_LAYERS = {
"default": {
"BACKEND": "authentik.root.channels.PostgresChannelLayer",
"CONFIG": {
**DATABASES["default"],
"TIME_ZONE": None,
},
},
}
# Email
# These values should never actually be used, emails are only sent from email stages, which
# loads the config directly from CONFIG
@@ -393,6 +399,8 @@ DRAMATIQ = {
).total_seconds(),
"middlewares": (
("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}),
# TODO: fixme
# ("dramatiq.middleware.prometheus.Prometheus", {}),
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
(
@@ -415,7 +423,7 @@ DRAMATIQ = {
},
),
("dramatiq.results.middleware.Results", {"store_results": True}),
("authentik.tasks.middleware.CurrentTask", {}),
("django_dramatiq_postgres.middleware.CurrentTask", {}),
("authentik.tasks.middleware.TenantMiddleware", {}),
("authentik.tasks.middleware.RelObjMiddleware", {}),
("authentik.tasks.middleware.MessagesMiddleware", {}),

View File

@@ -62,11 +62,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
"""Configure test environment settings"""
settings.TEST = True
settings.DRAMATIQ["test"] = True
settings.CHANNEL_LAYERS["default"]["CONFIG"] = {
**settings.DATABASES["default"],
**settings.DATABASES["default"]["TEST"],
"TIME_ZONE": None,
}
# Test-specific configuration
test_config = {

View File

@@ -2,6 +2,7 @@
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger
@@ -9,7 +10,7 @@ from authentik.lib.config import CONFIG
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"
@@ -32,7 +33,7 @@ def kerberos_connectivity_check(pk: str):
description=_("Sync Kerberos source."),
)
def kerberos_sync(pk: str):
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first()
if not source:
return

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from dramatiq.composition import group
from dramatiq.message import Message
@@ -20,7 +21,6 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -53,7 +53,7 @@ def ldap_connectivity_check(pk: str | None = None):
)
def ldap_sync(source_pk: str):
"""Sync a single source"""
task = CurrentTask.get_task()
task: Task = CurrentTask.get_task()
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk, enabled=True).first()
if not source:
return
@@ -127,7 +127,7 @@ def ldap_sync_paginator(
)
def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source:
# Because the source couldn't be found, we don't have a UID

View File

@@ -3,13 +3,14 @@
from json import dumps
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
@@ -20,7 +21,7 @@ LOGGER = get_logger()
)
)
def update_well_known_jwks():
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
session = get_http_session()
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
try:

View File

@@ -1,6 +1,7 @@
"""Plex tasks"""
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests import RequestException
@@ -8,13 +9,13 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string
from authentik.sources.plex.models import PlexSource
from authentik.sources.plex.plex import PlexAuth
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
@actor(description=_("Check the validity of a Plex source."))
def check_plex_token(source_pk: str):
"""Check the validity of a Plex source."""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
sources = PlexSource.objects.filter(pk=source_pk)
if not sources.exists():
return

View File

@@ -1,31 +0,0 @@
"""Telegram source property mappings API"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.telegram.models import TelegramSourcePropertyMapping
class TelegramSourcePropertyMappingSerializer(PropertyMappingSerializer):
"""TelegramSourcePropertyMapping Serializer"""
class Meta(PropertyMappingSerializer.Meta):
model = TelegramSourcePropertyMapping
class TelegramSourcePropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for TelegramSourcePropertyMapping"""
class Meta(PropertyMappingFilterSet.Meta):
model = TelegramSourcePropertyMapping
class TelegramSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""TelegramSourcePropertyMapping Viewset"""
queryset = TelegramSourcePropertyMapping.objects.all()
serializer_class = TelegramSourcePropertyMappingSerializer
filterset_class = TelegramSourcePropertyMappingFilter
search_fields = ["name"]
ordering = ["name"]

View File

@@ -1,41 +0,0 @@
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.telegram.models import TelegramSource
class TelegramSourceSerializer(SourceSerializer):
class Meta:
model = TelegramSource
fields = SourceSerializer.Meta.fields + [
"bot_username",
"bot_token",
"request_message_access",
"pre_authentication_flow",
]
extra_kwargs = {
"bot_token": {"write_only": True},
}
class TelegramSourceViewSet(UsedByMixin, ModelViewSet):
queryset = TelegramSource.objects.all()
serializer_class = TelegramSourceSerializer
lookup_field = "slug"
filterset_fields = [
"pbm_uuid",
"name",
"slug",
"enabled",
"authentication_flow",
"enrollment_flow",
"policy_engine_mode",
"user_matching_mode",
"group_matching_mode",
"bot_username",
"request_message_access",
]
search_fields = ["name", "slug"]
ordering = ["name"]

View File

@@ -1,33 +0,0 @@
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import (
GroupSourceConnectionSerializer,
GroupSourceConnectionViewSet,
UserSourceConnectionSerializer,
UserSourceConnectionViewSet,
)
from authentik.sources.telegram.models import (
GroupTelegramSourceConnection,
UserTelegramSourceConnection,
)
class UserTelegramSourceConnectionSerializer(UserSourceConnectionSerializer):
class Meta(UserSourceConnectionSerializer.Meta):
model = UserTelegramSourceConnection
fields = UserSourceConnectionSerializer.Meta.fields
class UserTelegramSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet):
queryset = UserTelegramSourceConnection.objects.all()
serializer_class = UserTelegramSourceConnectionSerializer
class GroupTelegramSourceConnectionSerializer(GroupSourceConnectionSerializer):
class Meta(GroupSourceConnectionSerializer.Meta):
model = GroupTelegramSourceConnection
class GroupTelegramSourceConnectionViewSet(GroupSourceConnectionViewSet, ModelViewSet):
queryset = GroupTelegramSourceConnection.objects.all()
serializer_class = GroupTelegramSourceConnectionSerializer

View File

@@ -1,9 +0,0 @@
from authentik.blueprints.apps import ManagedAppConfig
class TelegramConfig(ManagedAppConfig):
name = "authentik.sources.telegram"
label = "authentik_sources_telegram"
verbose_name = "authentik Sources.Telegram"
mountpoint = "source/telegram/"
default = True

View File

@@ -1,118 +0,0 @@
# Generated by Django 5.1.12 on 2025-09-24 07:14
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_core", "0050_user_last_updated_and_more"),
("authentik_flows", "0028_flowtoken_revoke_on_execution"),
]
operations = [
migrations.CreateModel(
name="GroupTelegramSourceConnection",
fields=[
(
"groupsourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.groupsourceconnection",
),
),
],
options={
"verbose_name": "Group Telegram Source Connection",
"verbose_name_plural": "Group Telegram Source Connections",
},
bases=("authentik_core.groupsourceconnection",),
),
migrations.CreateModel(
name="TelegramSourcePropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "Telegram Source Property Mapping",
"verbose_name_plural": "Telegram Source Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="UserTelegramSourceConnection",
fields=[
(
"usersourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.usersourceconnection",
),
),
],
options={
"verbose_name": "User Telegram Source Connection",
"verbose_name_plural": "User Telegram Source Connections",
},
bases=("authentik_core.usersourceconnection",),
),
migrations.CreateModel(
name="TelegramSource",
fields=[
(
"source_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.source",
),
),
("bot_username", models.TextField(help_text="Telegram bot username")),
("bot_token", models.TextField(help_text="Telegram bot token")),
(
"request_message_access",
models.BooleanField(
default=False, help_text="Request access to send messages from your bot."
),
),
(
"pre_authentication_flow",
models.ForeignKey(
help_text="Flow used before authentication.",
on_delete=django.db.models.deletion.CASCADE,
related_name="telegram_source_pre_authentication",
to="authentik_flows.flow",
),
),
],
options={
"verbose_name": "Telegram Source",
"verbose_name_plural": "Telegram Sources",
},
bases=("authentik_core.source",),
),
]

View File

@@ -1,156 +0,0 @@
"""Telegram source"""
from typing import Any
from django.db import models
from django.http import HttpRequest
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import BaseSerializer, Serializer
from authentik.core.models import (
GroupSourceConnection,
PropertyMapping,
Source,
UserSourceConnection,
)
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.models import Flow
class TelegramSource(Source):
"""Log in with Telegram."""
bot_username = models.TextField(help_text=_("Telegram bot username"))
bot_token = models.TextField(help_text=_("Telegram bot token"))
request_message_access = models.BooleanField(
default=False, help_text=_("Request access to send messages from your bot.")
)
pre_authentication_flow = models.ForeignKey(
Flow,
on_delete=models.CASCADE,
help_text=_("Flow used before authentication."),
related_name="telegram_source_pre_authentication",
)
@property
def component(self) -> str:
return "ak-source-telegram-form"
@property
def icon_url(self) -> str | None:
icon = super().icon_url
if not icon:
icon = static("authentik/sources/telegram.svg")
return icon
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.sources.telegram.api.source import TelegramSourceSerializer
return TelegramSourceSerializer
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
return UILoginButton(
challenge=RedirectChallenge(
data={
"to": reverse(
"authentik_sources_telegram:start",
kwargs={"source_slug": self.slug},
),
}
),
name=self.name,
icon_url=self.icon_url,
)
def ui_user_settings(self) -> UserSettingSerializer | None:
return UserSettingSerializer(
data={
"title": self.name,
"component": "ak-user-settings-source-telegram",
"icon_url": self.icon_url,
}
)
@property
def property_mapping_type(self) -> "type[PropertyMapping]":
return TelegramSourcePropertyMapping
def get_base_user_properties(
self, info: dict[str, Any] | None = None, **kwargs
) -> dict[str, Any | dict[str, Any]]:
info = info or {}
name = info.get("first_name", "")
if "last_name" in info:
name += " " + info["last_name"]
return {
"username": info.get("username", None),
"email": None,
"name": name if name else None,
}
def get_base_group_properties(self, group_id: str, **kwargs):
return {
"name": group_id,
}
class Meta:
verbose_name = _("Telegram Source")
verbose_name_plural = _("Telegram Sources")
class TelegramSourcePropertyMapping(PropertyMapping):
"""Map Telegram properties to User or Group object attributes"""
@property
def component(self) -> str:
return "ak-property-mapping-source-telegram-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.telegram.api.property_mappings import (
TelegramSourcePropertyMappingSerializer,
)
return TelegramSourcePropertyMappingSerializer
class Meta:
verbose_name = _("Telegram Source Property Mapping")
verbose_name_plural = _("Telegram Source Property Mappings")
class UserTelegramSourceConnection(UserSourceConnection):
"""Connect user and Telegram source"""
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.telegram.api.source_connection import (
UserTelegramSourceConnectionSerializer,
)
return UserTelegramSourceConnectionSerializer
class Meta:
verbose_name = _("User Telegram Source Connection")
verbose_name_plural = _("User Telegram Source Connections")
class GroupTelegramSourceConnection(GroupSourceConnection):
"""Group-source connection for Telegram"""
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.telegram.api.source_connection import (
GroupTelegramSourceConnectionSerializer,
)
return GroupTelegramSourceConnectionSerializer
class Meta:
verbose_name = _("Group Telegram Source Connection")
verbose_name_plural = _("Group Telegram Source Connections")

View File

@@ -1,48 +0,0 @@
import hashlib
import hmac
from datetime import datetime, timedelta
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import BooleanField, CharField, IntegerField, URLField
from rest_framework.serializers import ValidationError
from authentik.flows.challenge import Challenge, ChallengeResponse
from authentik.stages.identification.stage import LoginChallengeMixin
class TelegramLoginChallenge(LoginChallengeMixin, Challenge):
component = CharField(default="ak-source-telegram")
bot_username = CharField(help_text=_("Telegram bot username"))
request_message_access = BooleanField()
class TelegramChallengeResponse(ChallengeResponse):
component = CharField(default="ak-source-telegram")
id = IntegerField()
first_name = CharField(max_length=255, required=False)
last_name = CharField(max_length=255, required=False)
username = CharField(max_length=255, required=False)
photo_url = URLField(required=False)
auth_date = IntegerField(required=True)
hash = CharField(max_length=64, required=True)
def validate_auth_date(self, auth_date: int) -> int:
if datetime.fromtimestamp(auth_date) < datetime.now() - timedelta(minutes=5):
raise ValidationError(_("Authentication date is too old"))
return auth_date
def validate(self, attrs: dict) -> dict:
# Check the response as defined in https://core.telegram.org/widgets/login
attrs_to_check = attrs.copy()
attrs_to_check.pop("component")
attrs_to_check.pop("hash")
check_str = "\n".join([f"{key}={value}" for key, value in sorted(attrs_to_check.items())])
digest = hmac.new(
hashlib.sha256(self.stage.source.bot_token.encode("utf-8")).digest(),
check_str.encode("utf-8"),
"sha256",
).hexdigest()
if not hmac.compare_digest(digest, attrs["hash"]):
raise ValidationError(_("Invalid hash"))
return attrs

View File

@@ -1,185 +0,0 @@
"""Telegram source tests"""
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock
from django.test import TestCase
from django.urls import reverse
from rest_framework.exceptions import ValidationError
from authentik.core.tests.utils import create_test_flow
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.tests import FlowTestCase
from authentik.sources.telegram.stage import TelegramChallengeResponse
from authentik.stages.identification.models import IdentificationStage, UserFields
class MockTelegramResponseMixin:
def _add_hash(self, response):
to_hash = "\n".join([f"{key}={value}" for key, value in sorted(response.items())])
response["hash"] = hmac.new(
hashlib.sha256(self.source.bot_token.encode("utf-8")).digest(),
to_hash.encode("utf-8"),
"sha256",
).hexdigest()
def _make_valid_response(self):
resp = {
"id": "123456789",
"first_name": "Test",
"last_name": "User",
"username": "testuser",
"auth_date": str(int(datetime.now().timestamp())),
}
self._add_hash(resp)
return resp
def _make_outdated_response(self):
resp = self._make_valid_response()
resp["auth_date"] = str(int((datetime.now() - timedelta(days=1)).timestamp()))
self._add_hash(resp)
return resp
class TestTelegramSource(MockTelegramResponseMixin, TestCase):
"""Telegram Source tests"""
def setUp(self):
from authentik.sources.telegram.models import TelegramSource
self.source = TelegramSource.objects.create(
name="test",
slug="test",
bot_username="test_bot",
bot_token="modern_token", # nosec
request_message_access=True,
pre_authentication_flow=create_test_flow(),
)
self.mock_stage = Mock()
self.mock_stage.source = self.source
def test_ui_login_button(self):
"""Test UI login button"""
ui_login_button = self.source.ui_login_button(None)
self.assertIsNotNone(ui_login_button)
self.assertEqual(ui_login_button.name, "test")
self.assertTrue(ui_login_button.challenge.is_valid(raise_exception=True))
def test_challenge_response(self):
"""Test correct Telegram response validation"""
cr = TelegramChallengeResponse(data=self._make_valid_response())
cr.stage = self.mock_stage
self.assertTrue(cr.is_valid(raise_exception=True))
def test_outdated_challenge_response(self):
"""Test outdated Telegram response validation"""
cr = TelegramChallengeResponse(data=self._make_outdated_response())
cr.stage = self.mock_stage
with self.assertRaises(ValidationError):
cr.is_valid(raise_exception=True)
def test_invalid_hash_challenge_response(self):
"""Test invalid hash in Telegram response validation"""
resp = self._make_valid_response()
resp["hash"] = "invalid_hash"
cr = TelegramChallengeResponse(data=resp)
cr.stage = self.mock_stage
with self.assertRaises(ValidationError):
cr.is_valid(raise_exception=True)
def test_user_base_properties(self):
"""Test user base properties"""
cr = TelegramChallengeResponse(data=self._make_valid_response())
cr.stage = self.mock_stage
cr.is_valid(raise_exception=True)
properties = self.source.get_base_user_properties(info=cr.validated_data)
self.assertEqual(
properties,
{
"username": "testuser",
"name": "Test User",
"email": None,
},
)
def test_group_base_properties(self):
"""Test group base properties"""
for group_id in ["group 1", "group 2"]:
properties = self.source.get_base_group_properties(group_id=group_id)
self.assertEqual(properties, {"name": group_id})
class TestTelegramViews(MockTelegramResponseMixin, FlowTestCase):
"""Test Telegram source views"""
def setUp(self):
super().setUp()
from authentik.sources.telegram.models import TelegramSource
self.pre_auth_flow = create_test_flow()
self.source = TelegramSource.objects.create(
name="test",
slug="test",
bot_username="test_bot",
bot_token="modern_token", # nosec
request_message_access=True,
enrollment_flow=create_test_flow(),
pre_authentication_flow=self.pre_auth_flow,
)
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
self.stage = IdentificationStage.objects.create(
name="identification",
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
self.stage.sources.set([self.source])
self.stage.save()
FlowStageBinding.objects.create(
target=self.flow,
stage=self.stage,
order=0,
)
def _make_initial_request(self):
return self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
def _make_start_request(self):
return self.client.get(
reverse("authentik_sources_telegram:start", kwargs={"source_slug": self.source.slug}),
follow=True,
)
def test_start_view(self):
"""Test TelegramStartView"""
self.assertEqual(self._make_initial_request().status_code, 200)
response = self._make_start_request()
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.redirect_chain[0][0],
reverse("authentik_core:if-flow", kwargs={"flow_slug": self.pre_auth_flow.slug}),
)
def test_challenge_view(self):
"""Test TelegramLoginView"""
self._make_initial_request()
self._make_start_request()
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.pre_auth_flow.slug})
get_response = self.client.get(url)
self.assertEqual(get_response.status_code, 200)
form_data = self._make_valid_response()
form_data["component"] = "ak-source-telegram"
response = self.client.post(url, form_data)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response,
reverse(
"authentik_core:if-flow", kwargs={"flow_slug": self.source.enrollment_flow.slug}
),
)

View File

@@ -1,23 +0,0 @@
"""Telegram source API views"""
from django.urls import path
from authentik.sources.telegram.api.property_mappings import TelegramSourcePropertyMappingViewSet
from authentik.sources.telegram.api.source import TelegramSourceViewSet
from authentik.sources.telegram.api.source_connection import (
GroupTelegramSourceConnectionViewSet,
UserTelegramSourceConnectionViewSet,
)
from authentik.sources.telegram.views import TelegramLoginView, TelegramStartView
urlpatterns = [
path("<slug:source_slug>/start/", TelegramStartView.as_view(), name="start"),
path("<slug:source_slug>/", TelegramLoginView.as_view(), name="login"),
]
api_urlpatterns = [
("propertymappings/source/telegram", TelegramSourcePropertyMappingViewSet),
("sources/user_connections/telegram", UserTelegramSourceConnectionViewSet),
("sources/group_connections/telegram", GroupTelegramSourceConnectionViewSet),
("sources/telegram", TelegramSourceViewSet),
]

View File

@@ -1,98 +0,0 @@
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404
from django.views import View
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.flows.challenge import Challenge
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import (
PLAN_CONTEXT_REDIRECT,
PLAN_CONTEXT_SOURCE,
PLAN_CONTEXT_SSO,
FlowPlanner,
)
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_GET
from authentik.sources.telegram.models import (
GroupTelegramSourceConnection,
TelegramSource,
UserTelegramSourceConnection,
)
from authentik.sources.telegram.stage import TelegramChallengeResponse, TelegramLoginChallenge
class TelegramStartView(View):
def handle_login_flow(
self, source: TelegramSource, *stages_to_append, **kwargs
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user"
)
kwargs.update(
{
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: source,
PLAN_CONTEXT_REDIRECT: final_redirect,
}
)
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(source.pre_authentication_flow)
planner.allow_empty_flows = True
try:
plan = planner.plan(self.request, kwargs)
except FlowNonApplicableException:
raise Http404 from None
for stage in stages_to_append:
plan.append_stage(stage)
return plan.to_redirect(self.request, source.pre_authentication_flow)
def get(self, request: HttpRequest, source_slug: str) -> HttpResponse:
source = get_object_or_404(TelegramSource, slug=source_slug, enabled=True)
telegram_login_stage = in_memory_stage(TelegramLoginView)
return self.handle_login_flow(source, telegram_login_stage)
class TelegramSourceFlowManager(SourceFlowManager):
"""Flow manager for Telegram source"""
user_connection_type = UserTelegramSourceConnection
group_connection_type = GroupTelegramSourceConnection
class TelegramLoginView(ChallengeStageView):
response_class = TelegramChallengeResponse
def dispatch(self, request, *args, **kwargs):
self.source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> Challenge:
return TelegramLoginChallenge(
data={
"bot_username": self.source.bot_username,
"request_message_access": self.source.request_message_access,
},
)
def challenge_valid(self, response: TelegramChallengeResponse) -> HttpResponse:
raw_info = response.validated_data.copy()
raw_info.pop("component")
raw_info.pop("hash")
raw_info.pop("auth_date")
source = self.source
sfm = TelegramSourceFlowManager(
source=source,
request=self.request,
identifier=raw_info["id"],
user_info={"info": raw_info},
policy_context={"telegram": raw_info},
)
return sfm.get_flow(
raw_info=raw_info,
)

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from django.core.cache import cache
from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from fido2.mds3 import filter_revoked, parse_blob
@@ -14,7 +15,7 @@ from authentik.stages.authenticator_webauthn.models import (
UNKNOWN_DEVICE_TYPE_AAGUID,
WebAuthnDeviceType,
)
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no"
AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json"
@@ -32,7 +33,7 @@ def mds_ca() -> bytes:
@actor(description=_("Background task to import FIDO Alliance MDS blob and AAGUIDs into database."))
def webauthn_mds_import(force=False):
"""Background task to import FIDO Alliance MDS blob and AAGUIDs into database"""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
with open(MDS_BLOB_PATH, mode="rb") as _raw_blob:
blob = parse_blob(_raw_blob.read(), mds_ca())
to_create_update = [

View File

@@ -7,6 +7,7 @@ from django.core.mail import EmailMultiAlternatives
from django.core.mail.utils import DNS_NAME
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from dramatiq.composition import group
from structlog.stdlib import get_logger
@@ -16,31 +17,27 @@ from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.stages.authenticator_email.models import AuthenticatorEmailStage
from authentik.stages.email.models import EmailStage
from authentik.stages.email.utils import logo_data
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
LOGGER = get_logger()
def send_mails(
stage: EmailStage | AuthenticatorEmailStage | None, *messages: list[EmailMultiAlternatives]
stage: EmailStage | AuthenticatorEmailStage, *messages: list[EmailMultiAlternatives]
):
"""Wrapper to convert EmailMessage to dict and send it from worker
Args:
stage: Either an EmailStage or AuthenticatorEmailStage instance,
or nothing to use global settings
stage: Either an EmailStage or AuthenticatorEmailStage instance
messages: List of email messages to send
Returns:
Dramatiq group promise for the email sending tasks
"""
tasks = []
# Use the class path instead of the class itself for serialization
stage_class_path, stage_pk = None, None
if stage:
stage_class_path = class_to_path(stage.__class__)
stage_pk = str(stage.pk)
stage_class_path = class_to_path(stage.__class__)
for message in messages:
tasks.append(send_mail.message(message.__dict__, stage_class_path, stage_pk))
tasks.append(send_mail.message(message.__dict__, stage_class_path, str(stage.pk)))
return group(tasks).run()
@@ -59,7 +56,7 @@ def send_mail(
email_stage_pk: str | None = None,
):
"""Send Email for Email Stage. Retries are scheduled automatically."""
self = CurrentTask.get_task()
self: Task = CurrentTask.get_task()
message_id = make_msgid(domain=DNS_NAME)
self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_")))
if not stage_class_path or not email_stage_pk:

View File

@@ -123,7 +123,7 @@ class UserLoginStageView(ChallengeStageView):
def is_known_device(self, user: User):
"""Returns `True` if the login happened on a "known" device, by the same user."""
client_ip = ClientIPMiddleware.get_client_ip(self.request)
if AuthenticatedSession.objects.filter(session__last_ip=client_ip, user=user).exists():
if AuthenticatedSession.objects.filter(session__last_ip=client_ip).exists():
return True
if COOKIE_NAME_KNOWN_DEVICE not in self.request.COOKIES:
return False

View File

@@ -1,22 +1,21 @@
import socket
from http.server import BaseHTTPRequestHandler
from time import sleep
from typing import Any, cast
from typing import Any
import pglock
from django.db import OperationalError, connections
from django.utils.timezone import now
from django_dramatiq_postgres.middleware import (
CurrentTask as BaseCurrentTask,
)
from django_dramatiq_postgres.middleware import HTTPServer
from django_dramatiq_postgres.middleware import (
MetricsMiddleware as BaseMetricsMiddleware,
)
from django_redis import get_redis_connection
from dramatiq.broker import Broker
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from psycopg.errors import Error
from redis.exceptions import RedisError
from structlog.stdlib import get_logger
from authentik import authentik_full_version
@@ -29,13 +28,7 @@ from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
DB_ERRORS = (OperationalError, Error)
class CurrentTask(BaseCurrentTask):
@classmethod
def get_task(cls) -> Task:
return cast(Task, super().get_task())
DB_ERRORS = (OperationalError, Error, RedisError)
class TenantMiddleware(Middleware):
@@ -186,6 +179,8 @@ class _healthcheck_handler(BaseHTTPRequestHandler):
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
redis_conn = get_redis_connection()
redis_conn.ping()
self.send_response(200)
except DB_ERRORS: # pragma: no cover
self.send_response(503)

View File

@@ -1,48 +0,0 @@
from channels.db import database_sync_to_async
from django.db import close_old_connections, connection
from django.http.request import split_domain_port
from django_tenants.utils import (
get_public_schema_name,
remove_www,
)
from authentik.tenants.models import Domain, Tenant
class TenantsAwareMiddleware:
"""Set the database schema for use with django-tenants"""
def __init__(self, inner):
self.inner = inner
def get_hostname_from_scope(self, scope: list[tuple[bytes, bytes]]) -> str | None:
headers = {k.replace(b"-", b"_").upper(): v for k, v in scope.get("headers", [])}
hostname, _ = split_domain_port(headers.get(b"HOST", b"").decode("utf-8"))
if not hostname:
return None
return remove_www(hostname)
async def get_default_tenant(self) -> Tenant:
return await database_sync_to_async(Tenant.objects.get)(
schema_name=get_public_schema_name()
)
async def get_tenant(self, hostname: str | None) -> Tenant:
if not hostname:
return await self.get_default_tenant()
try:
domain = await database_sync_to_async(Domain.objects.select_related("tenant").get)(
domain=hostname
)
except Domain.DoesNotExist:
return await self.get_default_tenant()
return domain.tenant
async def __call__(self, scope, receive, send):
close_old_connections()
hostname = self.get_hostname_from_scope(scope)
tenant = await self.get_tenant(hostname)
scope["tenant"] = tenant
connection.set_tenant(tenant)
return await self.inner(scope, receive, send)

View File

@@ -3016,166 +3016,6 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_telegram.grouptelegramsourceconnection"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"created",
"must_created",
"present"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_telegram.telegramsource"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"created",
"must_created",
"present"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsource_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsource"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsource"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_telegram.telegramsourcepropertymapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"created",
"must_created",
"present"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_telegram.usertelegramsourceconnection"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"created",
"must_created",
"present"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection"
}
}
},
{
"type": "object",
"required": [
@@ -5431,22 +5271,6 @@
"authentik_sources_scim.view_scimsourcegroup",
"authentik_sources_scim.view_scimsourcepropertymapping",
"authentik_sources_scim.view_scimsourceuser",
"authentik_sources_telegram.add_grouptelegramsourceconnection",
"authentik_sources_telegram.add_telegramsource",
"authentik_sources_telegram.add_telegramsourcepropertymapping",
"authentik_sources_telegram.add_usertelegramsourceconnection",
"authentik_sources_telegram.change_grouptelegramsourceconnection",
"authentik_sources_telegram.change_telegramsource",
"authentik_sources_telegram.change_telegramsourcepropertymapping",
"authentik_sources_telegram.change_usertelegramsourceconnection",
"authentik_sources_telegram.delete_grouptelegramsourceconnection",
"authentik_sources_telegram.delete_telegramsource",
"authentik_sources_telegram.delete_telegramsourcepropertymapping",
"authentik_sources_telegram.delete_usertelegramsourceconnection",
"authentik_sources_telegram.view_grouptelegramsourceconnection",
"authentik_sources_telegram.view_telegramsource",
"authentik_sources_telegram.view_telegramsourcepropertymapping",
"authentik_sources_telegram.view_usertelegramsourceconnection",
"authentik_stages_authenticator_duo.add_authenticatorduostage",
"authentik_stages_authenticator_duo.add_duodevice",
"authentik_stages_authenticator_duo.change_authenticatorduostage",
@@ -7512,7 +7336,6 @@
"authentik.sources.plex",
"authentik.sources.saml",
"authentik.sources.scim",
"authentik.sources.telegram",
"authentik.stages.authenticator",
"authentik.stages.authenticator_duo",
"authentik.stages.authenticator_email",
@@ -7623,10 +7446,6 @@
"authentik_sources_saml.groupsamlsourceconnection",
"authentik_sources_scim.scimsource",
"authentik_sources_scim.scimsourcepropertymapping",
"authentik_sources_telegram.telegramsource",
"authentik_sources_telegram.telegramsourcepropertymapping",
"authentik_sources_telegram.usertelegramsourceconnection",
"authentik_sources_telegram.grouptelegramsourceconnection",
"authentik_stages_authenticator_duo.authenticatorduostage",
"authentik_stages_authenticator_duo.duodevice",
"authentik_stages_authenticator_email.authenticatoremailstage",
@@ -10157,22 +9976,6 @@
"authentik_sources_scim.view_scimsourcegroup",
"authentik_sources_scim.view_scimsourcepropertymapping",
"authentik_sources_scim.view_scimsourceuser",
"authentik_sources_telegram.add_grouptelegramsourceconnection",
"authentik_sources_telegram.add_telegramsource",
"authentik_sources_telegram.add_telegramsourcepropertymapping",
"authentik_sources_telegram.add_usertelegramsourceconnection",
"authentik_sources_telegram.change_grouptelegramsourceconnection",
"authentik_sources_telegram.change_telegramsource",
"authentik_sources_telegram.change_telegramsourcepropertymapping",
"authentik_sources_telegram.change_usertelegramsourceconnection",
"authentik_sources_telegram.delete_grouptelegramsourceconnection",
"authentik_sources_telegram.delete_telegramsource",
"authentik_sources_telegram.delete_telegramsourcepropertymapping",
"authentik_sources_telegram.delete_usertelegramsourceconnection",
"authentik_sources_telegram.view_grouptelegramsourceconnection",
"authentik_sources_telegram.view_telegramsource",
"authentik_sources_telegram.view_telegramsourcepropertymapping",
"authentik_sources_telegram.view_usertelegramsourceconnection",
"authentik_stages_authenticator_duo.add_authenticatorduostage",
"authentik_stages_authenticator_duo.add_duodevice",
"authentik_stages_authenticator_duo.change_authenticatorduostage",
@@ -12265,289 +12068,6 @@
}
}
},
"model_authentik_sources_telegram.grouptelegramsourceconnection": {
"type": "object",
"properties": {
"group": {
"type": "string",
"format": "uuid",
"title": "Group"
},
"source": {
"type": "integer",
"title": "Source"
},
"identifier": {
"type": "string",
"minLength": 1,
"title": "Identifier"
},
"icon": {
"type": "string",
"minLength": 1,
"title": "Icon"
}
},
"required": []
},
"model_authentik_sources_telegram.grouptelegramsourceconnection_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_grouptelegramsourceconnection",
"change_grouptelegramsourceconnection",
"delete_grouptelegramsourceconnection",
"view_grouptelegramsourceconnection"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_telegram.telegramsource": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name",
"description": "Source's display Name."
},
"slug": {
"type": "string",
"maxLength": 50,
"minLength": 1,
"pattern": "^[-a-zA-Z0-9_]+$",
"title": "Slug",
"description": "Internal source name, used in URLs."
},
"enabled": {
"type": "boolean",
"title": "Enabled"
},
"authentication_flow": {
"type": "string",
"format": "uuid",
"title": "Authentication flow",
"description": "Flow to use when authenticating existing users."
},
"enrollment_flow": {
"type": "string",
"format": "uuid",
"title": "Enrollment flow",
"description": "Flow to use when enrolling new users."
},
"user_property_mappings": {
"type": "array",
"items": {
"type": "string",
"format": "uuid"
},
"title": "User property mappings"
},
"group_property_mappings": {
"type": "array",
"items": {
"type": "string",
"format": "uuid"
},
"title": "Group property mappings"
},
"policy_engine_mode": {
"type": "string",
"enum": [
"all",
"any"
],
"title": "Policy engine mode"
},
"user_matching_mode": {
"type": "string",
"enum": [
"identifier",
"email_link",
"email_deny",
"username_link",
"username_deny"
],
"title": "User matching mode",
"description": "How the source determines if an existing user should be authenticated or a new user enrolled."
},
"user_path_template": {
"type": "string",
"minLength": 1,
"title": "User path template"
},
"icon": {
"type": "string",
"minLength": 1,
"title": "Icon"
},
"bot_username": {
"type": "string",
"minLength": 1,
"title": "Bot username",
"description": "Telegram bot username"
},
"bot_token": {
"type": "string",
"minLength": 1,
"title": "Bot token",
"description": "Telegram bot token"
},
"request_message_access": {
"type": "boolean",
"title": "Request message access",
"description": "Request access to send messages from your bot."
},
"pre_authentication_flow": {
"type": "string",
"format": "uuid",
"title": "Pre authentication flow",
"description": "Flow used before authentication."
}
},
"required": []
},
"model_authentik_sources_telegram.telegramsource_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_telegramsource",
"change_telegramsource",
"delete_telegramsource",
"view_telegramsource"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_telegram.telegramsourcepropertymapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"minLength": 1,
"title": "Expression"
}
},
"required": []
},
"model_authentik_sources_telegram.telegramsourcepropertymapping_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_telegramsourcepropertymapping",
"change_telegramsourcepropertymapping",
"delete_telegramsourcepropertymapping",
"view_telegramsourcepropertymapping"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_telegram.usertelegramsourceconnection": {
"type": "object",
"properties": {
"user": {
"type": "integer",
"title": "User"
},
"source": {
"type": "integer",
"title": "Source"
},
"identifier": {
"type": "string",
"minLength": 1,
"title": "Identifier"
},
"icon": {
"type": "string",
"minLength": 1,
"title": "Icon"
}
},
"required": []
},
"model_authentik_sources_telegram.usertelegramsourceconnection_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_usertelegramsourceconnection",
"change_usertelegramsourceconnection",
"delete_usertelegramsourceconnection",
"view_usertelegramsourceconnection"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_stages_authenticator_duo.authenticatorduostage": {
"type": "object",
"properties": {

6
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/coreos/go-oidc/v3 v3.15.0
github.com/getsentry/sentry-go v0.35.3
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.12
github.com/go-ldap/ldap/v3 v3.4.11
github.com/go-openapi/runtime v0.29.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
@@ -32,7 +32,7 @@ require (
github.com/spf13/cobra v1.10.1
github.com/stretchr/testify v1.11.1
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2025100.16
goauthentik.io/api/v3 v3.2025100.14
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.31.0
golang.org/x/sync v0.17.0
@@ -96,5 +96,3 @@ require (
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace goauthentik.io/api/v3 => ./gen-go-api

12
go.sum
View File

@@ -4,8 +4,8 @@ beryju.io/radius-eap v0.1.0 h1:5M3HwkzH3nIEBcKDA2z5+sb4nCY3WdKL/SDDKTBvoqw=
beryju.io/radius-eap v0.1.0/go.mod h1:yYtO59iyoLNEepdyp1gZ0i1tGdjPbrR2M/v5yOz7Fkc=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk=
@@ -42,8 +42,8 @@ github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a h1:v6zMvHuY9
github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a/go.mod h1:I79BieaU4fxrw4LMXby6q5OS9XnoR9UIKLOzDFjUmuw=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-ldap/ldap/v3 v3.4.11 h1:4k0Yxweg+a3OyBLjdYn5OKglv18JNvfDykSoI8bW0gU=
github.com/go-ldap/ldap/v3 v3.4.11/go.mod h1:bY7t0FLK8OAVpp/vV6sSlpz3EQDGcQwc8pF0ujLgKvM=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -207,8 +207,8 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
goauthentik.io/api/v3 v3.2025100.16 h1:+tkfBSjpIt7eo267Y86LDmprrxo7bxiuLVKh/dZuGMg=
goauthentik.io/api/v3 v3.2025100.16/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
goauthentik.io/api/v3 v3.2025100.14 h1:fBRPhJ+nMIzD3AHC8ofQcBs6ZKvOWD+r/tOCnbZupX0=
goauthentik.io/api/v3 v3.2025100.14/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=

View File

@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.1-bookworm AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25-bookworm AS builder
ARG TARGETOS
ARG TARGETARCH

View File

@@ -9,8 +9,8 @@
"version": "0.0.0",
"license": "MIT",
"devDependencies": {
"aws-cdk": "^2.1029.4",
"cross-env": "^10.1.0"
"aws-cdk": "^2.1029.3",
"cross-env": "^10.0.0"
},
"engines": {
"node": ">=20"
@@ -24,9 +24,9 @@
"license": "MIT"
},
"node_modules/aws-cdk": {
"version": "2.1029.4",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1029.4.tgz",
"integrity": "sha512-rJa8QLd8WHaoTEjPLqVwmNpDMmyJycVaxdr/Evr/1MDLq+WCovP46IqPaXfH0q/jY0gCsga9or907tEayK5xcg==",
"version": "2.1029.3",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1029.3.tgz",
"integrity": "sha512-otRJP5a4r07S+SLKs/WvJH+0auZHkaRMnv1vtD4fpp1figV8Vr9MKdB4QPNjfKdLGyK9f95OEHwVlIW9xpjPBg==",
"dev": true,
"license": "Apache-2.0",
"bin": {
@@ -40,9 +40,9 @@
}
},
"node_modules/cross-env": {
"version": "10.1.0",
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-10.1.0.tgz",
"integrity": "sha512-GsYosgnACZTADcmEyJctkJIoqAhHjttw7RsFrVoJNXbsWWqaq6Ym+7kZjq6mS45O0jij6vtiReppKQEtqWy6Dw==",
"version": "10.0.0",
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-10.0.0.tgz",
"integrity": "sha512-aU8qlEK/nHYtVuN4p7UQgAwVljzMg8hB4YK5ThRqD2l/ziSnryncPNn7bMLt5cFYsKVKBh8HqLqyCoTupEUu7Q==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -10,7 +10,7 @@
"node": ">=20"
},
"devDependencies": {
"aws-cdk": "^2.1029.4",
"cross-env": "^10.1.0"
"aws-cdk": "^2.1029.3",
"cross-env": "^10.0.0"
}
}

View File

@@ -1,5 +1,7 @@
# flake8: noqa
from redis import Redis
from authentik.lib.config import CONFIG
from lifecycle.migrate import BaseMigration
SQL_STATEMENT = """BEGIN TRANSACTION;
@@ -104,3 +106,17 @@ class Migration(BaseMigration):
def run(self):
with self.con.transaction():
self.cur.execute(SQL_STATEMENT)
# We also need to clean the cache to make sure no pickeled objects still exist
for db in [
CONFIG.get("redis.message_queue_db"),
CONFIG.get("redis.cache_db"),
CONFIG.get("redis.ws_db"),
]:
redis = Redis(
host=CONFIG.get("redis.host"),
port=6379,
db=db,
username=CONFIG.get("redis.username"),
password=CONFIG.get("redis.password"),
)
redis.flushall()

View File

@@ -1,13 +1,14 @@
#!/usr/bin/env python
"""This file needs to be run from the root of the project to correctly
import authentik. This is done by the dockerfile."""
from sys import exit as sysexit
from time import sleep
from psycopg import OperationalError, connect
from redis import Redis
from redis.exceptions import RedisError
from authentik.lib.config import CONFIG
from authentik.lib.config import CONFIG, redis_url
CHECK_THRESHOLD = 30
@@ -39,6 +40,24 @@ def check_postgres():
CONFIG.log("info", "PostgreSQL connection successful")
def check_redis():
url = CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db"))
attempt = 0
while True:
if attempt >= CHECK_THRESHOLD:
sysexit(1)
try:
redis = Redis.from_url(url)
redis.ping()
break
except RedisError as exc:
sleep(1)
CONFIG.log("info", f"Redis Connection failed, retrying... ({exc})")
finally:
attempt += 1
CONFIG.log("info", "Redis Connection successful")
def wait_for_db():
CONFIG.log("info", "Starting authentik bootstrap")
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
@@ -50,6 +69,7 @@ def wait_for_db():
CONFIG.log("info", "----------------------------------------------------------------------")
sysexit(1)
check_postgres()
check_redis()
CONFIG.log("info", "Finished authentik bootstrap")

Binary file not shown.

View File

@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-10-02 00:10+0000\n"
"POT-Creation-Date: 2025-09-26 00:10+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@@ -152,18 +152,6 @@ msgstr ""
msgid "No empty segments in user path allowed."
msgstr ""
#: authentik/core/api/users.py
msgid "A user/group with these details already exists"
msgstr ""
#: authentik/core/api/users.py
msgid "Unable to create user"
msgstr ""
#: authentik/core/api/users.py
msgid "Unknown error occurred"
msgstr ""
#: authentik/core/api/users.py
msgid "This field is required."
msgstr ""
@@ -1678,13 +1666,6 @@ msgid ""
"minutes=2;seconds=3)."
msgstr ""
#: authentik/providers/oauth2/models.py
msgid ""
"When refreshing a token, if the refresh token is valid for less than this "
"duration, it will be renewed. When set to seconds=0, token will always be "
"renewed. (Format: hours=1;minutes=2;seconds=3)."
msgstr ""
#: authentik/providers/oauth2/models.py
msgid ""
"Configure what data should be used as unique User Identifier. For most "
@@ -2817,12 +2798,6 @@ msgstr ""
msgid "Check the validity of a Plex source."
msgstr ""
#: authentik/sources/saml/api/source.py
msgid ""
"With a Verification Certificate selected, at least one of 'Verify Assertion "
"Signature' or 'Verify Response Signature' must be selected."
msgstr ""
#: authentik/sources/saml/models.py
msgid "Redirect Binding"
msgstr ""
@@ -2835,7 +2810,7 @@ msgstr ""
msgid "POST Binding with auto-confirmation"
msgstr ""
#: authentik/sources/saml/models.py authentik/sources/telegram/models.py
#: authentik/sources/saml/models.py
msgid "Flow used before authentication."
msgstr ""
@@ -2942,58 +2917,6 @@ msgstr ""
msgid "SCIM Source Property Mappings"
msgstr ""
#: authentik/sources/telegram/models.py authentik/sources/telegram/stage.py
msgid "Telegram bot username"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Telegram bot token"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Request access to send messages from your bot."
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Telegram Source"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Telegram Sources"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Telegram Source Property Mapping"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Telegram Source Property Mappings"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "User Telegram Source Connection"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "User Telegram Source Connections"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Group Telegram Source Connection"
msgstr ""
#: authentik/sources/telegram/models.py
msgid "Group Telegram Source Connections"
msgstr ""
#: authentik/sources/telegram/stage.py
msgid "Authentication date is too old"
msgstr ""
#: authentik/sources/telegram/stage.py
msgid "Invalid hash"
msgstr ""
#: authentik/stages/authenticator_duo/models.py
msgid "Duo Authenticator Setup Stage"
msgstr ""

View File

@@ -11,13 +11,10 @@ class DjangoDramatiqPostgres(AppConfig):
name = "django_dramatiq_postgres"
verbose_name = "Django Dramatiq postgres"
def ready(self) -> None:
try:
old_broker = dramatiq.get_broker()
except ModuleNotFoundError:
old_broker = None
def ready(self):
old_broker = dramatiq.get_broker()
if old_broker is not None and len(old_broker.actors) != 0:
if len(old_broker.actors) != 0:
raise ImproperlyConfigured(
"Actors were previously registered. "
"Make sure your actors are not imported too early."
@@ -41,7 +38,7 @@ class DjangoDramatiqPostgres(AppConfig):
*Conf().result_backend_args,
**Conf().result_backend_kwargs,
)
broker.add_middleware(middleware) # type: ignore[no-untyped-call]
broker.add_middleware(middleware)
dramatiq.set_broker(broker)

View File

@@ -1,10 +1,9 @@
import functools
import logging
import time
from collections.abc import Callable, Iterable
from datetime import timedelta
from collections.abc import Iterable
from queue import Empty, Queue
from typing import Any, ParamSpec, TypeVar, cast
from typing import Any
import tenacity
from django.core.exceptions import ImproperlyConfigured
@@ -38,36 +37,31 @@ from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, T
logger = get_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]:
def raise_connection_error(func):
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except OperationalError as exc:
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
raise ConnectionError(str(exc)) from exc
return wrapper
class PostgresBroker(Broker):
queues: set[str] # type: ignore[assignment]
def __init__(
self,
*args: Any,
*args,
middleware: list[Middleware] | None = None,
db_alias: str = DEFAULT_DB_ALIAS,
**kwargs: Any,
) -> None:
super().__init__(*args, middleware=[], **kwargs) # type: ignore[no-untyped-call,misc]
**kwargs,
):
super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger(__name__, type(self))
self.queues = set()
@@ -81,7 +75,7 @@ class PostgresBroker(Broker):
@property
def connection(self) -> DatabaseWrapper:
return cast(DatabaseWrapper, connections[self.db_alias])
return connections[self.db_alias]
@property
def consumer_class(self) -> "type[_PostgresConsumer]":
@@ -89,12 +83,11 @@ class PostgresBroker(Broker):
@cached_property
def model(self) -> type[TaskBase]:
model: type[TaskBase] = import_string(Conf().task_model)
return model
return import_string(Conf().task_model)
@property
def query_set(self) -> QuerySet[TaskBase]:
return self.model._default_manager.using(self.db_alias).defer("message", "result")
def query_set(self) -> QuerySet:
return self.model.objects.using(self.db_alias).defer("message", "result")
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
self.declare_queue(queue_name)
@@ -106,18 +99,18 @@ class PostgresBroker(Broker):
timeout=timeout,
)
def declare_queue(self, queue_name: str) -> None:
def declare_queue(self, queue_name: str):
if queue_name not in self.queues:
self.emit_before("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_before("declare_queue", queue_name)
self.queues.add(queue_name)
# Nothing to do, all queues are in the same table
self.emit_after("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_after("declare_queue", queue_name)
delayed_name = dq_name(queue_name) # type: ignore[no-untyped-call]
delayed_name = dq_name(queue_name)
self.delay_queues.add(delayed_name)
self.emit_after("declare_delay_queue", delayed_name) # type: ignore[no-untyped-call]
self.emit_after("declare_delay_queue", delayed_name)
def model_defaults(self, message: Message[Any]) -> dict[str, Any]:
def model_defaults(self, message: Message) -> dict[str, Any]:
return {
"queue_name": message.queue_name,
"actor_name": message.actor_name,
@@ -137,16 +130,14 @@ class PostgresBroker(Broker):
reraise=True,
wait=tenacity.wait_random_exponential(multiplier=1, max=5),
stop=tenacity.stop_after_attempt(3),
before_sleep=tenacity.before_sleep_log(
cast(logging.Logger, logger), logging.INFO, exc_info=True
),
before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
)
def enqueue(self, message: Message[Any], *, delay: int | None = None) -> Message[Any]:
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
canonical_queue_name = message.queue_name
queue_name = canonical_queue_name
if delay:
queue_name = dq_name(queue_name) # type: ignore[no-untyped-call]
message_eta = current_millis() + delay # type: ignore[no-untyped-call]
queue_name = dq_name(queue_name)
message_eta = current_millis() + delay
message = message.copy(
queue_name=queue_name,
options={
@@ -160,7 +151,7 @@ class PostgresBroker(Broker):
)
message.options["model_defaults"] = self.model_defaults(message)
self.emit_before("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_before("enqueue", message, delay)
with transaction.atomic(using=self.db_alias):
query = {
@@ -182,18 +173,18 @@ class PostgresBroker(Broker):
message.options["task"] = task
message.options["task_created"] = created
self.emit_after("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_after("enqueue", message, delay)
return message
def get_declared_queues(self) -> set[str]:
return self.queues.copy()
def flush(self, queue_name: str) -> None:
def flush(self, queue_name: str):
self.query_set.filter(
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name)) # type: ignore[no-untyped-call]
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
).delete()
def flush_all(self) -> None:
def flush_all(self):
for queue_name in self.queues:
self.flush(queue_name)
@@ -203,11 +194,11 @@ class PostgresBroker(Broker):
interval: int = 100,
*,
timeout: int | None = None,
) -> None:
):
deadline = timeout and time.monotonic() + timeout / 1000
while True:
if deadline and time.monotonic() >= deadline:
raise QueueJoinTimeout(queue_name) # type: ignore[no-untyped-call]
raise QueueJoinTimeout(queue_name)
if self.query_set.filter(
queue_name=queue_name,
@@ -221,14 +212,14 @@ class PostgresBroker(Broker):
class _PostgresConsumer(Consumer):
def __init__(
self,
*args: Any,
*args,
broker: PostgresBroker,
db_alias: str,
queue_name: str,
prefetch: int,
timeout: int,
**kwargs: Any,
) -> None:
**kwargs,
):
self.logger = get_logger(__name__, type(self))
self.notifies: list[Notify] = []
@@ -236,8 +227,8 @@ class _PostgresConsumer(Consumer):
self.db_alias = db_alias
self.queue_name = queue_name
self.timeout = timeout // 1000
self.unlock_queue: Queue[str] = Queue()
self.in_processing: set[str] = set()
self.unlock_queue = Queue()
self.in_processing = set()
self.prefetch = prefetch
self.misses = 0
self._listen_connection: DatabaseWrapper | None = None
@@ -246,34 +237,32 @@ class _PostgresConsumer(Consumer):
# Override because dramatiq doesn't allow us setting this manually
self.timeout = Conf().worker["consumer_listen_timeout"]
self.lock_purge_interval = timedelta(seconds=Conf().lock_purge_interval)
self.lock_purge_interval = timezone.timedelta(seconds=Conf().lock_purge_interval)
self.lock_purge_last_run = timezone.now()
self.task_purge_interval = timedelta(seconds=Conf().task_purge_interval)
self.task_purge_interval = timezone.timedelta(seconds=Conf().task_purge_interval)
self.task_purge_last_run = timezone.now() - self.task_purge_interval
self.scheduler = None
if Conf().schedule_model:
self.scheduler = import_string(Conf().scheduler_class)()
self.scheduler.broker = self.broker
self.scheduler_interval = timedelta(seconds=Conf().scheduler_interval)
self.scheduler_interval = timezone.timedelta(seconds=Conf().scheduler_interval)
self.scheduler_last_run = timezone.now() - self.scheduler_interval
@property
def connection(self) -> DatabaseWrapper:
return cast(DatabaseWrapper, connections[self.db_alias])
return connections[self.db_alias]
@property
def query_set(self) -> QuerySet[TaskBase]:
def query_set(self) -> QuerySet:
return self.broker.query_set
@property
def listen_connection(self) -> DatabaseWrapper:
if self._listen_connection is not None and self._listen_connection.is_usable():
return self._listen_connection
self._listen_connection = cast(
DatabaseWrapper, connections.create_connection(self.db_alias)
)
self._listen_connection = connections.create_connection(self.db_alias)
# Required for notifications
# See https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-notifications
# Should be set to True by Django by default
@@ -283,7 +272,7 @@ class _PostgresConsumer(Consumer):
return self._listen_connection
@raise_connection_error
def ack(self, message: Message[Any]) -> None:
def ack(self, message: Message):
task = message.options.pop("task", None)
self.query_set.filter(
message_id=message.message_id,
@@ -298,7 +287,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.remove(message.message_id)
@raise_connection_error
def nack(self, message: Message[Any]) -> None:
def nack(self, message: Message):
task = message.options.pop("task", None)
self.query_set.filter(
message_id=message.message_id,
@@ -314,7 +303,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.remove(message.message_id)
@raise_connection_error
def requeue(self, messages: Iterable[Message[Any]]) -> None:
def requeue(self, messages: Iterable[Message]):
self.query_set.filter(
message_id__in=[message.message_id for message in messages],
).update(
@@ -337,9 +326,7 @@ class _PostgresConsumer(Consumer):
)
.values_list("message_id", flat=True)
)
return [
Notify(pid=0, channel=self.postgres_channel, payload=str(item)) for item in notifies
]
return [Notify(pid=0, channel=self.postgres_channel, payload=item) for item in notifies]
def _poll_for_notify(self) -> list[Notify]:
with self.listen_connection.cursor() as cursor:
@@ -352,12 +339,11 @@ class _PostgresConsumer(Consumer):
return notifies
def _get_message_lock_id(self, message_id: str) -> int:
lock_id = _cast_lock_id(
return _cast_lock_id(
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}"
) # type: ignore[no-untyped-call]
return cast(int, lock_id)
)
def _consume_one(self, message: Message[Any]) -> bool:
def _consume_one(self, message: Message) -> bool:
if message.message_id in self.in_processing:
self.logger.debug("Message already consumed by self", message_id=message.message_id)
return False
@@ -402,7 +388,7 @@ class _PostgresConsumer(Consumer):
# If we have too many messages already processing, wait and don't consume a message
# straight away, other workers will be faster.
# After waiting consume a message regardless.
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000) # type: ignore[no-untyped-call]
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
self.logger.debug(
"Too many messages in processing, Sleeping",
processing=processing,
@@ -426,11 +412,11 @@ class _PostgresConsumer(Consumer):
)
if task is None:
continue
message = Message.decode(cast(bytes, task.message))
message = Message.decode(task.message)
message.options["task"] = task
if self._consume_one(message):
self.in_processing.add(message.message_id)
return MessageProxy(message) # type: ignore[no-untyped-call]
return MessageProxy(message)
else:
self.logger.debug(
"Message already consumed. Skipping.", message_id=message.message_id
@@ -443,7 +429,7 @@ class _PostgresConsumer(Consumer):
self.misses = 0
return None
def _purge_locks(self) -> None:
def _purge_locks(self):
if timezone.now() - self.lock_purge_last_run < self.lock_purge_interval:
return
while True:
@@ -459,19 +445,19 @@ class _PostgresConsumer(Consumer):
self.unlock_queue.task_done()
self.lock_purge_last_run = timezone.now()
def _auto_purge(self) -> None:
def _auto_purge(self):
if timezone.now() - self.task_purge_last_run < self.task_purge_interval:
return
self.logger.debug("Running garbage collector")
count = self.query_set.filter(
state__in=(TaskState.DONE, TaskState.REJECTED),
mtime__lte=timezone.now() - timedelta(seconds=Conf().task_expiration),
mtime__lte=timezone.now() - timezone.timedelta(seconds=Conf().task_purge_interval),
result_expiry__lte=timezone.now(),
).delete()
self.logger.info("Purged messages in all queues", count=count)
self.task_purge_last_run = timezone.now()
def _scheduler(self) -> None:
def _scheduler(self):
if not self.scheduler:
return
if timezone.now() - self.scheduler_last_run < self.scheduler_interval:
@@ -480,7 +466,7 @@ class _PostgresConsumer(Consumer):
self.schedule_last_run = timezone.now()
@raise_connection_error
def close(self) -> None:
def close(self):
try:
self._purge_locks()
finally:

View File

@@ -1,11 +1,11 @@
from typing import Any, cast
from typing import Any
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
class Conf:
def __init__(self) -> None:
def __init__(self):
try:
_ = settings.DRAMATIQ
except AttributeError as exc:
@@ -19,75 +19,68 @@ class Conf:
@property
def encoder_class(self) -> str:
return cast(str, self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder"))
return self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder")
@property
def broker_class(self) -> str:
return cast(
str, self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker")
)
return self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker")
@property
def broker_args(self) -> tuple[Any]:
return cast(tuple[Any], self.conf.get("broker_args", tuple()))
return self.conf.get("broker_args", ())
@property
def broker_kwargs(self) -> dict[str, Any]:
return cast(dict[str, Any], self.conf.get("broker_kwargs", {}))
return self.conf.get("broker_kwargs", {})
@property
def middlewares(self) -> tuple[tuple[str, dict[str, Any]]]:
return cast(
tuple[tuple[str, dict[str, Any]]],
self.conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
return self.conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
)
@property
def channel_prefix(self) -> str:
return cast(str, self.conf.get("channel_prefix", "dramatiq"))
return self.conf.get("channel_prefix", "dramatiq")
@property
def task_model(self) -> str:
return cast(str, self.conf["task_model"])
return self.conf["task_model"]
@property
def lock_purge_interval(self) -> int:
return cast(int, self.conf.get("lock_purge_interval", 60))
return self.conf.get("lock_purge_interval", 60)
@property
def task_purge_interval(self) -> int:
# 24 hours
return cast(int, self.conf.get("task_purge_interval", 24 * 60 * 60))
return self.conf.get("task_purge_interval", 24 * 60 * 60)
@property
def task_expiration(self) -> int:
# 30 days
return cast(int, self.conf.get("task_expiration", 60 * 60 * 24 * 30))
return self.conf.get("task_expiration", 60 * 60 * 24 * 30)
@property
def result_backend(self) -> str:
return cast(
str, self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend")
)
return self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend")
@property
def result_backend_args(self) -> tuple[Any]:
return cast(tuple[Any], self.conf.get("result_backend_args", ()))
return self.conf.get("result_backend_args", ())
@property
def result_backend_kwargs(self) -> dict[str, Any]:
return cast(dict[str, Any], self.conf.get("result_backend_kwargs", {}))
return self.conf.get("result_backend_kwargs", {})
@property
def autodiscovery(self) -> dict[str, Any]:
@@ -120,18 +113,16 @@ class Conf:
@property
def scheduler_class(self) -> str:
return cast(
str, self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler")
)
return self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler")
@property
def schedule_model(self) -> str | None:
return cast(str | None, self.conf.get("schedule_model"))
return self.conf.get("schedule_model")
@property
def scheduler_interval(self) -> int:
return cast(int, self.conf.get("scheduler_interval", 60))
return self.conf.get("scheduler_interval", 60)
@property
def test(self) -> bool:
return cast(bool, self.conf.get("test", False))
return self.conf.get("test", False)

View File

@@ -1,11 +1,12 @@
import os
from importlib import import_module
from signal import pause
from django.utils.module_loading import import_module
from django_dramatiq_postgres.conf import Conf
def worker_metrics() -> None:
def worker_metrics():
import_module(Conf().autodiscovery["setup_module"])
from django_dramatiq_postgres.middleware import MetricsMiddleware

View File

@@ -1,12 +1,11 @@
import sys
from argparse import Namespace
from typing import Any
from django.apps.registry import apps
from django.core.management.base import BaseCommand, CommandParser
from django.core.management.base import BaseCommand
from django.db import connections
from django.utils.module_loading import import_string, module_has_submodule
from dramatiq.cli import main
from dramatiq.__main__ import main
from django_dramatiq_postgres.conf import Conf
@@ -14,7 +13,7 @@ from django_dramatiq_postgres.conf import Conf
class Command(BaseCommand):
"""Run worker"""
def add_arguments(self, parser: CommandParser) -> None:
def add_arguments(self, parser):
parser.add_argument(
"--pid-file",
action="store",
@@ -32,11 +31,11 @@ class Command(BaseCommand):
def handle(
self,
pid_file: str,
watch: bool,
verbosity: int,
**options: Any,
) -> None:
pid_file,
watch,
verbosity,
**options,
):
worker = Conf().worker
setup, modules = self._discover_tasks_modules()
args = Namespace(
@@ -71,7 +70,7 @@ class Command(BaseCommand):
args.verbose = verbosity - 1
connections.close_all()
sys.exit(main(args)) # type: ignore[no-untyped-call]
sys.exit(main(args))
def _discover_tasks_modules(self) -> tuple[str, list[str]]:
# Does not support a tasks directory

View File

@@ -1,13 +1,15 @@
import contextvars
import os
import socket
from collections.abc import Callable
from http.server import BaseHTTPRequestHandler
from http.server import HTTPServer as BaseHTTPServer
from ipaddress import IPv6Address, ip_address
from typing import Any, cast
from typing import Any
from django.db import DatabaseError, close_old_connections, connections
from django.db import (
close_old_connections,
connections,
)
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.common import current_millis
@@ -20,11 +22,10 @@ from django_dramatiq_postgres.models import TaskBase
class HTTPServer(BaseHTTPServer):
def server_bind(self) -> None:
def server_bind(self):
self.socket.close()
host, port = self.server_address[:2]
host = cast(str, host)
if host == "0.0.0.0" and socket.has_dualstack_ipv6(): # nosec
host = "::" # nosec
@@ -51,7 +52,7 @@ class HTTPServer(BaseHTTPServer):
class DbConnectionMiddleware(Middleware):
def _close_old_connections(self, *args: Any, **kwargs: Any) -> None:
def _close_old_connections(self, *args, **kwargs):
if Conf().test:
return
close_old_connections()
@@ -59,7 +60,7 @@ class DbConnectionMiddleware(Middleware):
before_process_message = _close_old_connections
after_process_message = _close_old_connections
def _close_connections(self, *args: Any, **kwargs: Any) -> None:
def _close_connections(self, *args, **kwargs):
connections.close_all()
before_consumer_thread_shutdown = _close_connections
@@ -68,7 +69,7 @@ class DbConnectionMiddleware(Middleware):
class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor[Any, Any]) -> None:
def before_declare_actor(self, broker: Broker, actor: Actor):
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
@@ -79,7 +80,7 @@ class CurrentTaskNotFound(Exception):
class CurrentTask(Middleware):
def __init__(self) -> None:
def __init__(self):
self.logger = get_logger(__name__, type(self))
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
@@ -95,7 +96,7 @@ class CurrentTask(Middleware):
raise CurrentTaskNotFound()
return task[-1]
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: Message):
tasks = self._TASKS.get()
if tasks is None:
tasks = []
@@ -105,11 +106,11 @@ class CurrentTask(Middleware):
def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
):
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
return
@@ -128,19 +129,13 @@ class CurrentTask(Middleware):
fields_to_update = [
f.name
for f in task._meta.get_fields()
if f.name not in fields_to_exclude
and f.concrete
and not f.auto_created
and not f.many_to_many
if f.name not in fields_to_exclude and not f.auto_created and f.column
]
if fields_to_update:
try:
task.save(update_fields=fields_to_update)
except DatabaseError:
pass
task.save(update_fields=fields_to_update)
self._TASKS.set(tasks[:-1])
def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
def after_skip_message(self, broker: Broker, message: Message):
self.after_process_message(broker, message)
@@ -149,21 +144,21 @@ class MetricsMiddleware(Middleware):
self,
prefix: str,
labels: list[str] | None = None,
) -> None:
):
super().__init__()
self.prefix = prefix
self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]
self.delayed_messages: set[str] = set()
self.message_start_times: dict[str, int] = {}
self.delayed_messages = set()
self.message_start_times = {}
@property
def forks(self) -> list[Callable[[], None]]:
def forks(self):
from django_dramatiq_postgres.forks import worker_metrics
return [worker_metrics]
def before_worker_boot(self, broker: Broker, worker: Any) -> None:
def before_worker_boot(self, broker: Broker, worker):
if Conf().test:
return
@@ -229,47 +224,47 @@ class MetricsMiddleware(Middleware):
),
)
def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
def after_worker_shutdown(self, broker: Broker, worker):
from prometheus_client import multiprocess
# TODO: worker_id
multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call]
multiprocess.mark_process_dead(os.getpid())
def _make_labels(self, message: Message[Any]) -> list[str]:
def _make_labels(self, message: Message) -> list[str]:
return [message.queue_name, message.actor_name]
def after_nack(self, broker: Broker, message: Message[Any]) -> None:
def after_nack(self, broker: Broker, message: Message):
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
def after_enqueue(self, broker: Broker, message: Message, delay: int):
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()
def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
def before_delay_message(self, broker: Broker, message: Message):
self.delayed_messages.add(message.message_id)
self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: Message):
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.in_progress_delayed_messages.labels(*labels).dec()
self.in_progress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call]
self.message_start_times[message.message_id] = current_millis()
def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
):
labels = self._make_labels(message)
message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call]
message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call]
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
message_duration = current_millis() - message_start_time
self.messages_durations.labels(*labels).observe(message_duration)
self.in_progress_messages.labels(*labels).dec()
@@ -280,7 +275,7 @@ class MetricsMiddleware(Middleware):
after_skip_message = after_process_message
@staticmethod
def run(addr: str, port: int) -> None:
def run(addr: str, port: int):
try:
server = HTTPServer((addr, port), _MetricsHandler)
server.serve_forever()
@@ -291,7 +286,7 @@ class MetricsMiddleware(Middleware):
class _MetricsHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
def do_GET(self):
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
@@ -300,13 +295,13 @@ class _MetricsHandler(BaseHTTPRequestHandler):
)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call]
multiprocess.MultiProcessCollector(registry)
output = generate_latest(registry)
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
self.end_headers()
self.wfile.write(output)
def log_message(self, format: str, *args: Any) -> None:
def log_message(self, format, *args):
logger = get_logger(__name__, type(self))
logger.debug(format, *args)

View File

@@ -1,7 +1,5 @@
import pickle # nosec
from datetime import datetime, timedelta
from enum import StrEnum, auto
from typing import Any
from uuid import uuid4
import pgtrigger
@@ -10,7 +8,7 @@ from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db import models
from django.utils.timezone import now
from django.utils.timezone import datetime, now, timedelta
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.broker import Broker, get_broker
@@ -73,11 +71,11 @@ class TaskBase(models.Model):
),
)
def __str__(self) -> str:
def __str__(self):
return str(self.message_id)
def validate_crontab(value: str) -> None:
def validate_crontab(value):
try:
Cron(value)
except ValueError as exc:
@@ -121,14 +119,14 @@ class ScheduleBase(models.Model):
),
)
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._original_crontab = self.crontab
def __str__(self) -> str:
def __str__(self):
return f"Schedule {self.actor_name} ({self.id})"
def save(self, *args: Any, **kwargs: Any) -> None:
def save(self, *args, **kwargs):
if self.crontab != self._original_crontab:
self.next_run = self.compute_next_run(now())
@@ -137,16 +135,16 @@ class ScheduleBase(models.Model):
self._original_crontab = self.crontab
@classmethod
def dispatch_by_actor(cls, actor: Actor[Any, Any]) -> None:
def dispatch_by_actor(cls, actor: Actor):
"""Dispatch a schedule by looking up its actor.
Only available for schedules without custom arguments."""
schedule = cls._default_manager.filter(actor_name=actor.actor_name, paused=False).first()
schedule = cls.objects.filter(actor_name=actor.actor_name, paused=False).first()
if schedule:
schedule.send()
def send(self, broker: Broker | None = None) -> Message[Any]:
def send(self, broker: Broker | None = None) -> Message:
broker = broker or get_broker()
actor: Actor[Any, Any] = broker.get_actor(self.actor_name) # type: ignore[no-untyped-call]
actor: Actor = broker.get_actor(self.actor_name)
return actor.send_with_options(
args=pickle.loads(self.args), # nosec
kwargs=pickle.loads(self.kwargs), # nosec
@@ -155,7 +153,7 @@ class ScheduleBase(models.Model):
)
def compute_next_run(self, next_run: datetime | None = None) -> datetime:
next_run = self.next_run if not next_run else next_run
next_run: datetime = self.next_run if not next_run else next_run
while True:
next_run = Cron(self.crontab).schedule(next_run).next()
if next_run > now():

View File

@@ -1,6 +1,3 @@
from datetime import timedelta
from typing import Any, cast
from django.db import DEFAULT_DB_ALIAS
from django.db.models import QuerySet
from django.utils import timezone
@@ -14,27 +11,26 @@ from django_dramatiq_postgres.models import TaskBase
class PostgresBackend(ResultBackend):
def __init__(self, *args: Any, db_alias: str = DEFAULT_DB_ALIAS, **kwargs: Any) -> None:
def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, **kwargs):
super().__init__(*args, **kwargs)
self.db_alias = db_alias
@cached_property
def model(self) -> type[TaskBase]:
model: type[TaskBase] = import_string(Conf().task_model)
return model
return import_string(Conf().task_model)
@property
def query_set(self) -> QuerySet[TaskBase]:
return self.model._default_manager.using(self.db_alias).defer("message")
def query_set(self) -> QuerySet:
return self.model.objects.using(self.db_alias).defer("message")
def build_message_key(self, message: Message[Result]) -> str:
def build_message_key(self, message: Message) -> str:
return str(message.message_id)
def _get(self, message_key: str) -> MResult:
message = self.query_set.filter(message_id=message_key).first()
if message is None:
return Missing
data = cast(bytes | None, message.result)
data = message.result
if data is None:
return Missing
return self.encoder.decode(data)
@@ -43,5 +39,5 @@ class PostgresBackend(ResultBackend):
self.query_set.filter(message_id=message_key).update(
mtime=timezone.now(),
result=self.encoder.encode(result),
result_expiry=timezone.now() + timedelta(milliseconds=ttl),
result_expiry=timezone.now() + timezone.timedelta(milliseconds=ttl),
)

View File

@@ -1,5 +1,3 @@
from typing import Any, cast
import pglock
from django.db import router, transaction
from django.db.models import QuerySet
@@ -16,21 +14,19 @@ from django_dramatiq_postgres.models import ScheduleBase
class Scheduler:
broker: Broker
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger(__name__, type(self))
@cached_property
def model(self) -> type[ScheduleBase]:
schedule_model = cast(str, Conf().schedule_model)
model: type[ScheduleBase] = import_string(schedule_model)
return model
return import_string(Conf().schedule_model)
@property
def query_set(self) -> QuerySet[ScheduleBase]:
return self.model._default_manager.filter(paused=False)
def query_set(self) -> QuerySet:
return self.model.objects.filter(paused=False)
def process_schedule(self, schedule: ScheduleBase) -> None:
def process_schedule(self, schedule: ScheduleBase):
schedule.next_run = schedule.compute_next_run()
schedule.send(self.broker)
schedule.save()
@@ -52,11 +48,10 @@ class Scheduler:
count += 1
return count
def run(self) -> int:
def run(self):
with self._lock() as lock_acquired:
if not lock_acquired:
self.logger.debug("Could not acquire lock, skipping scheduling")
return -1
return
count = self._run()
self.logger.info("Sent scheduled tasks", count=count)
return count

View File

@@ -27,7 +27,6 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Programming Language :: Python",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
]
dependencies = [
@@ -36,7 +35,7 @@ dependencies = [
"django-pgtrigger >=4,<5",
"dramatiq[watch] >=1.17,<1.18",
"tenacity >=9,<10",
"structlog >=25,<26",
"structlog >=25,<26"
]
[project.urls]

View File

@@ -1,12 +0,0 @@
# django-postgres-cache
### Use in migrations
Migrations that use the cache with this installed need to depend on the migration to create the cache entry table:
```python
dependencies = [
# ...other requirements
("django_postgres_cache", "0001_initial"),
]
```

View File

@@ -1,6 +0,0 @@
from django.apps import AppConfig
class DjangoPostgresCache(AppConfig):
name = "django_postgres_cache"
verbose_name = "Django Postgres cache"

View File

@@ -1,51 +0,0 @@
from typing import Any
from django.core.cache.backends.db import DatabaseCache as BaseDatabaseCache
from django.db.utils import ProgrammingError
from django.utils.module_loading import import_string
from django.utils.timezone import now
from django_postgres_cache.models import CacheEntry
class DatabaseCache(BaseDatabaseCache):
def __init__(self, table: str, params: dict[str, Any]) -> None:
super().__init__(table, params)
self.reverse_key_func = import_string(params["REVERSE_KEY_FUNCTION"])
self._table = CacheEntry._meta.db_table
self.cache_model_class = CacheEntry
def _cull(self, *args: Any, **kwargs: Any) -> None:
"""Stubbed out cull method as we cull in a background task"""
pass
def get(self, key: str, default: Any | None = None, version: int | None = None) -> Any:
try:
return super().get(key, default=default, version=version)
except ProgrammingError:
return default
def keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
try:
return self._keys(keys_pattern, version=version)
except ProgrammingError:
return []
def _keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
keys_pattern = self.make_key(keys_pattern.replace("*", ".*"), version=version)
return [
self.reverse_key_func(key)
for key in CacheEntry.objects.filter(cache_key__regex=keys_pattern).values_list(
"cache_key", flat=True
)
]
def ttl(self, key: str, version: int | None = None) -> int | None:
"""Get TTL left for a given key and version"""
key = self.make_and_validate_key(key, version=version)
entry = CacheEntry.objects.filter(cache_key=key).first()
if not entry:
return None
return int((entry.expires - now()).total_seconds())

View File

@@ -1,24 +0,0 @@
# Generated by Django 5.1.12 on 2025-09-06 16:16
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = []
operations = [
migrations.CreateModel(
name="CacheEntry",
fields=[
("cache_key", models.TextField(primary_key=True, serialize=False)),
("value", models.TextField()),
("expires", models.DateTimeField(db_index=True)),
],
options={
"default_permissions": [],
},
),
]

View File

@@ -1,14 +0,0 @@
from django.db import models
class CacheEntry(models.Model):
cache_key = models.TextField(primary_key=True)
value = models.TextField()
expires = models.DateTimeField(db_index=True)
class Meta:
default_permissions = []
def __str__(self) -> str:
return f"Cache entry '{self.cache_key}'"

Some files were not shown because too many files have changed in this diff Show More