mirror of https://github.com/goauthentik/authentik
synced 2026-05-06 07:02:51 +02:00

Compare commits: flows/conc ... a11y-times (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | fbd9460720 |  |
.github/actions/setup/docker-compose.yml (vendored): 1 change

@@ -3,7 +3,6 @@ services:
     image: docker.io/library/postgres:${PSQL_TAG:-16}
     volumes:
       - db-data:/var/lib/postgresql/data
     command: "-c log_statement=all"
     environment:
       POSTGRES_USER: authentik
       POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"
.github/actions/test-results/action.yml (vendored): 28 changes

@@ -1,28 +0,0 @@
-name: "Process test results"
-description: Convert test results to JUnit, add them to GitHub Actions and codecov
-
-inputs:
-  flags:
-    description: Codecov flags
-
-runs:
-  using: "composite"
-  steps:
-    - uses: codecov/codecov-action@v5
-      with:
-        flags: ${{ inputs.flags }}
-        use_oidc: true
-    - uses: codecov/test-results-action@v1
-      with:
-        flags: ${{ inputs.flags }}
-        file: unittest.xml
-        use_oidc: true
-    - name: PostgreSQL Logs
-      shell: bash
-      run: |
-        if [[ $ACTIONS_RUNNER_DEBUG == 'true' || $ACTIONS_STEP_DEBUG == 'true' ]]; then
-          docker stop setup-postgresql-1
-          echo "::group::PostgreSQL Logs"
-          docker logs setup-postgresql-1
-          echo "::endgroup::"
-        fi
@@ -74,7 +74,7 @@ jobs:
           mkdir -p ./gen-go-api
       - name: Setup node
         if: ${{ !inputs.release }}
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v4
         with:
           node-version-file: web/package.json
           cache: "npm"
.github/workflows/ci-main.yml (vendored): 37 changes

@@ -113,10 +113,6 @@ jobs:
           CI_TOTAL_RUNS: "5"
         run: |
           uv run make ci-test
-      - uses: ./.github/actions/test-results
-        if: ${{ always() }}
-        with:
-          flags: unit-migrate
  test-unittest:
    name: test-unittest - PostgreSQL ${{ matrix.psql }} - Run ${{ matrix.run_id }}/5
    runs-on: ubuntu-latest
@@ -143,10 +139,17 @@ jobs:
           CI_TOTAL_RUNS: "5"
         run: |
           uv run make ci-test
-      - uses: ./.github/actions/test-results
-        if: ${{ always() }}
-        with:
-          flags: unit
+      - if: ${{ always() }}
+        uses: codecov/codecov-action@v5
+        with:
+          flags: unit
+          use_oidc: true
+      - if: ${{ !cancelled() }}
+        uses: codecov/test-results-action@v1
+        with:
+          flags: unit
+          file: unittest.xml
+          use_oidc: true
  test-integration:
    runs-on: ubuntu-latest
    timeout-minutes: 30
@@ -160,10 +163,17 @@ jobs:
        run: |
          uv run coverage run manage.py test tests/integration
          uv run coverage xml
-      - uses: ./.github/actions/test-results
-        if: ${{ always() }}
-        with:
-          flags: integration
+      - if: ${{ always() }}
+        uses: codecov/codecov-action@v5
+        with:
+          flags: integration
+          use_oidc: true
+      - if: ${{ !cancelled() }}
+        uses: codecov/test-results-action@v1
+        with:
+          flags: integration
+          file: unittest.xml
+          use_oidc: true
  test-e2e:
    name: test-e2e (${{ matrix.job.name }})
    runs-on: ubuntu-latest
@@ -212,10 +222,17 @@ jobs:
        run: |
          uv run coverage run manage.py test ${{ matrix.job.glob }}
          uv run coverage xml
-      - uses: ./.github/actions/test-results
-        if: ${{ always() }}
-        with:
-          flags: e2e
+      - if: ${{ always() }}
+        uses: codecov/codecov-action@v5
+        with:
+          flags: e2e
+          use_oidc: true
+      - if: ${{ !cancelled() }}
+        uses: codecov/test-results-action@v1
+        with:
+          flags: e2e
+          file: unittest.xml
+          use_oidc: true
  ci-core-mark:
    if: always()
    needs:
.github/workflows/translation-advice.yml (vendored): 4 changes

@@ -20,14 +20,14 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Find Comment
-        uses: peter-evans/find-comment@v4
+        uses: peter-evans/find-comment@v3
        id: fc
        with:
          issue-number: ${{ github.event.pull_request.number }}
          comment-author: "github-actions[bot]"
          body-includes: authentik translations instructions
      - name: Create or update comment
-        uses: peter-evans/create-or-update-comment@v5
+        uses: peter-evans/create-or-update-comment@v4
        with:
          comment-id: ${{ steps.fc.outputs.comment-id }}
          issue-number: ${{ github.event.pull_request.number }}
@@ -24,7 +24,6 @@ Makefile @goauthentik/infrastructure
 .editorconfig @goauthentik/infrastructure
 CODEOWNERS @goauthentik/infrastructure
 # Backend packages
-packages/django-postgres-cache @goauthentik/backend
 packages/django-dramatiq-postgres @goauthentik/backend
 # Web packages
 packages/docusaurus-config @goauthentik/frontend
@@ -26,7 +26,7 @@ RUN npm run build && \
     npm run build:sfe

 # Stage 2: Build go proxy
-FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.1-bookworm AS go-builder
+FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25-bookworm AS go-builder

 ARG TARGETOS
 ARG TARGETARCH
@@ -119,11 +119,7 @@ RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/v
     libltdl-dev && \
     curl https://sh.rustup.rs -sSf | sh -s -- -y

-ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec" \
-    # https://github.com/rust-lang/rustup/issues/2949
-    # Fixes issues where the rust version in the build cache is older than latest
-    # and rustup tries to update it, which fails
-    RUSTUP_PERMIT_COPY_RENAME="true"
+ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"

 RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
     --mount=type=bind,target=uv.lock,src=uv.lock \
@@ -1,9 +0,0 @@
-from django.dispatch import receiver
-
-from authentik.admin.tasks import _set_prom_info
-from authentik.root.signals import post_startup
-
-
-@receiver(post_startup)
-def post_startup_admin_metrics(sender, **_):
-    _set_prom_info()
@@ -2,6 +2,7 @@

 from django.core.cache import cache
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq import actor
 from packaging.version import parse
 from requests import RequestException
@@ -12,7 +13,7 @@ from authentik.admin.apps import PROM_INFO
 from authentik.events.models import Event, EventAction
 from authentik.lib.config import CONFIG
 from authentik.lib.utils.http import get_http_session
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 LOGGER = get_logger()
 VERSION_NULL = "0.0.0"
@@ -34,7 +35,7 @@ def _set_prom_info():

 @actor(description=_("Update latest version info."))
 def update_latest_version():
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     if CONFIG.get_bool("disable_update_check"):
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
         self.info("Version check disabled.")
@@ -71,3 +72,6 @@ def update_latest_version():
     except (RequestException, IndexError) as exc:
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
         raise exc
+
+
+_set_prom_info()
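The same two-line pattern recurs in nearly every task module in this diff: `CurrentTask` moves from `authentik.tasks.middleware` to `django_dramatiq_postgres.middleware`, and `get_task()` calls gain a `Task` annotation. A minimal sketch of the actor shape as it appears in the hunks above (`example_task` is a made-up name):

```python
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq import actor

from authentik.tasks.models import Task


@actor(description=_("Example task."))
def example_task():
    # CurrentTask.get_task() returns the Task record backing this actor run;
    # the annotation only adds type information, behavior is unchanged.
    self: Task = CurrentTask.get_task()
    self.info("Progress messages are attached to the Task record.")
```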
@@ -12,7 +12,7 @@ from django.db import DatabaseError, InternalError, ProgrammingError
 from django.utils.text import slugify
 from django.utils.timezone import now
 from django.utils.translation import gettext_lazy as _
-from django_dramatiq_postgres.middleware import CurrentTaskNotFound
+from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
 from dramatiq.actor import actor
 from dramatiq.middleware import Middleware
 from structlog.stdlib import get_logger
@@ -39,7 +39,6 @@ from authentik.events.logs import capture_logs
 from authentik.events.utils import sanitize_dict
 from authentik.lib.config import CONFIG
 from authentik.tasks.apps import PRIORITY_HIGH
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task
 from authentik.tasks.schedules.models import Schedule
 from authentik.tenants.models import Tenant
@@ -156,7 +155,7 @@ def blueprints_find() -> list[BlueprintFile]:
     throws=(DatabaseError, ProgrammingError, InternalError),
 )
 def blueprints_discovery(path: str | None = None):
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     count = 0
     for blueprint in blueprints_find():
         if path and blueprint.path != path:
@@ -196,7 +195,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
 @actor(description=_("Apply single blueprint."))
 def apply_blueprint(instance_pk: UUID):
     try:
-        self = CurrentTask.get_task()
+        self: Task = CurrentTask.get_task()
     except CurrentTaskNotFound:
         self = Task()
         self.set_uid(str(instance_pk))
@@ -334,21 +334,6 @@ class UserPasswordSetSerializer(PassiveSerializer):
     password = CharField(required=True)


-class UserServiceAccountSerializer(PassiveSerializer):
-    """Payload to create a service account"""
-
-    name = CharField(
-        required=True,
-        validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
-    )
-    create_group = BooleanField(default=False)
-    expiring = BooleanField(default=True)
-    expires = DateTimeField(
-        required=False,
-        help_text="If not provided, valid for 360 days",
-    )
-
-
 class UsersFilter(FilterSet):
     """Filter for users"""

@@ -509,7 +494,18 @@ class UserViewSet(UsedByMixin, ModelViewSet):

     @permission_required(None, ["authentik_core.add_user", "authentik_core.add_token"])
     @extend_schema(
-        request=UserServiceAccountSerializer,
+        request=inline_serializer(
+            "UserServiceAccountSerializer",
+            {
+                "name": CharField(required=True),
+                "create_group": BooleanField(default=False),
+                "expiring": BooleanField(default=True),
+                "expires": DateTimeField(
+                    required=False,
+                    help_text="If not provided, valid for 360 days",
+                ),
+            },
+        ),
         responses={
             200: inline_serializer(
                 "UserServiceAccountResponse",
@@ -531,12 +527,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
     )
     def service_account(self, request: Request) -> Response:
         """Create a new user account that is marked as a service account"""
-        data = UserServiceAccountSerializer(data=request.data)
-        data.is_valid(raise_exception=True)
-        expires = data.validated_data.get("expires", now() + timedelta(days=360))
+        username = request.data.get("name")
+        create_group = request.data.get("create_group", False)
+        expiring = request.data.get("expiring", True)
+        expires = request.data.get("expires", now() + timedelta(days=360))

-        username = data.validated_data["name"]
-        expiring = data.validated_data["expiring"]
         with atomic():
             try:
                 user: User = User.objects.create(
@@ -554,10 +549,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
                     "user_uid": user.uid,
                     "user_pk": user.pk,
                 }
-                if data.validated_data["create_group"] and self.request.user.has_perm(
-                    "authentik_core.add_group"
-                ):
-                    group = Group.objects.create(name=username)
+                if create_group and self.request.user.has_perm("authentik_core.add_group"):
+                    group = Group.objects.create(
+                        name=username,
+                    )
                     group.users.add(user)
                     response["group_pk"] = str(group.pk)
                 token = Token.objects.create(
@@ -570,29 +565,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
                 response["token"] = token.key
                 return Response(response)
             except IntegrityError as exc:
-                error_msg = str(exc).lower()
-
-                if "unique" in error_msg:
-                    return Response(
-                        data={
-                            "non_field_errors": [
-                                _("A user/group with these details already exists")
-                            ]
-                        },
-                        status=400,
-                    )
-                else:
-                    LOGGER.warning("Service account creation failed", exc=exc)
-                    return Response(
-                        data={"non_field_errors": [_("Unable to create user")]},
-                        status=400,
-                    )
-        except (ValueError, TypeError) as exc:
-            LOGGER.error("Unexpected error during service account creation", exc=exc)
-            return Response(
-                data={"non_field_errors": [_("Unknown error occurred")]},
-                status=500,
-            )
+                return Response(data={"non_field_errors": [str(exc)]}, status=400)

     @extend_schema(responses={200: SessionUserSerializer(many=False)})
     @action(
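On the new side the request schema is declared inline rather than through the deleted `UserServiceAccountSerializer`. As a hedged sketch of the drf-spectacular API in play (`ExampleRequest` and `example_view` are made-up names): `inline_serializer(name, fields)` builds a one-off serializer class, which here documents the request body but no longer performs the field validation the removed serializer provided:

```python
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.fields import BooleanField, CharField


# Hypothetical endpoint: inline_serializer creates a throwaway serializer
# class, used here only for OpenAPI schema generation.
@extend_schema(
    request=inline_serializer(
        "ExampleRequest",
        {
            "name": CharField(required=True),
            "create_group": BooleanField(default=False),
        },
    )
)
def example_view(self, request): ...
```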
@@ -13,6 +13,14 @@ import authentik.core.models
 import authentik.lib.models


+def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
+    from django.contrib.sessions.backends.cache import KEY_PREFIX
+    from django.core.cache import cache
+
+    session_keys = cache.keys(KEY_PREFIX + "*")
+    cache.delete_many(session_keys)
+
+
 def fix_duplicates(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     db_alias = schema_editor.connection.alias
     Token = apps.get_model("authentik_core", "token")
@@ -143,6 +151,9 @@ class Migration(migrations.Migration):
                 "abstract": False,
             },
         ),
+        migrations.RunPython(
+            code=migrate_sessions,
+        ),
         migrations.AlterField(
             model_name="application",
             name="meta_launch_url",
@@ -7,10 +7,15 @@ from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
 from django.db import migrations, models
 import django.db.models.deletion
 from django.conf import settings
+from django.contrib.sessions.backends.cache import KEY_PREFIX
+from django.utils.timezone import now, timedelta
+from authentik.lib.migrations import progress_bar
 from authentik.root.middleware import ClientIPMiddleware


+SESSION_CACHE_ALIAS = "default"
+
+
 class PickleSerializer:
     """
     Simple wrapper around pickle to be used in signing.dumps()/loads() and
@@ -78,6 +83,27 @@ def _migrate_session(
     )


+def migrate_redis_sessions(apps, schema_editor):
+    from django.core.cache import caches
+
+    db_alias = schema_editor.connection.alias
+    cache = caches[SESSION_CACHE_ALIAS]
+
+    # Not a redis cache, skipping
+    if not hasattr(cache, "keys"):
+        return
+
+    print("\nMigrating Redis sessions to database, this might take a couple of minutes...")
+    for key, session_data in progress_bar(cache.get_many(cache.keys(f"{KEY_PREFIX}*")).items()):
+        _migrate_session(
+            apps=apps,
+            db_alias=db_alias,
+            session_key=key.removeprefix(KEY_PREFIX),
+            session_data=session_data,
+            expires=now() + timedelta(seconds=cache.ttl(key)),
+        )
+
+
 def migrate_database_sessions(apps, schema_editor):
     DjangoSession = apps.get_model("sessions", "Session")
     db_alias = schema_editor.connection.alias
@@ -205,6 +231,10 @@ class Migration(migrations.Migration):
                 "verbose_name_plural": "Authenticated Sessions",
             },
         ),
+        migrations.RunPython(
+            code=migrate_redis_sessions,
+            reverse_code=migrations.RunPython.noop,
+        ),
         migrations.RunPython(
             code=migrate_database_sessions,
             reverse_code=migrations.RunPython.noop,
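`migrate_redis_sessions` guards on `hasattr(cache, "keys")` because `.keys()` is a django-redis extension, not part of Django's generic cache API. A sketch of the overall RunPython pattern both session migrations use (model and function names illustrative):

```python
from django.db import migrations


def forwards(apps, schema_editor):
    # Historical models must come from the migration state, not direct imports.
    Session = apps.get_model("authentik_core", "Session")  # illustrative name
    ...


class Migration(migrations.Migration):
    operations = [
        # reverse_code=noop lets the migration be unapplied without trying to
        # reconstruct the cache-backed sessions that were consumed going forward.
        migrations.RunPython(code=forwards, reverse_code=migrations.RunPython.noop),
    ]
```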
@@ -406,8 +406,6 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):

     def locale(self, request: HttpRequest | None = None) -> str:
         """Get the locale the user has configured"""
-        if request and hasattr(request, "LANGUAGE_CODE"):
-            return request.LANGUAGE_CODE
         try:
             return self.attributes.get("settings", {}).get("locale", "")
@@ -4,7 +4,7 @@ from datetime import datetime, timedelta

 from django.utils.timezone import now
 from django.utils.translation import gettext_lazy as _
-from django_postgres_cache.tasks import clear_expired_cache
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq.actor import actor
 from structlog.stdlib import get_logger

@@ -15,14 +15,14 @@ from authentik.core.models import (
     User,
 )
 from authentik.lib.utils.db import chunked_queryset
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 LOGGER = get_logger()


 @actor(description=_("Remove expired objects."))
 def clean_expired_models():
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     for cls in ExpiringModel.__subclasses__():
         cls: ExpiringModel
         objects = (
@@ -33,12 +33,11 @@ def clean_expired_models():
             obj.expire_action()
         LOGGER.debug("Expired models", model=cls, amount=amount)
         self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
-    clear_expired_cache()


 @actor(description=_("Remove temporary users created by SAML Sources."))
 def clean_temporary_users():
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     _now = datetime.now()
     deleted_users = 0
     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
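`clean_expired_models` discovers its cleanup targets through Python's `__subclasses__()` hook rather than an explicit registry. A minimal sketch of that dispatch, with the filter fields assumed from how `ExpiringModel` is used in the hunk above:

```python
from django.utils.timezone import now

from authentik.core.models import ExpiringModel

for cls in ExpiringModel.__subclasses__():
    # Every concrete subclass defines its own expiry semantics in expire_action();
    # the expires/expiring field names are assumptions based on the hunk above.
    for obj in cls.objects.filter(expires__lt=now(), expiring=True):
        obj.expire_action()
```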
@@ -469,274 +469,3 @@ class TestUsersAPI(APITestCase):
         body = loads(response.content)
         self.assertEqual(len(body["results"]), 2)
         self.assertEqual(body["results"][0]["pk"], user.pk)
-
-    def test_service_account_validation_empty_username(self):
-        """Test service account creation with empty/blank username validation"""
-        self.client.force_login(self.admin)
-
-        # Test with empty string
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"name": ["This field may not be blank."]},
-        )
-
-        # Test with only whitespace
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "   ",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"name": ["This field may not be blank."]},
-        )
-
-        # Test with tab and newline characters
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "\t\n",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"name": ["This field may not be blank."]},
-        )
-
-    def test_service_account_validation_valid_username(self):
-        """Test service account creation with valid username"""
-        self.client.force_login(self.admin)
-
-        # Test with valid username
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "valid-service-account",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 200)
-
-        # Verify response structure
-        body = loads(response.content)
-        self.assertIn("username", body)
-        self.assertIn("user_uid", body)
-        self.assertIn("user_pk", body)
-        self.assertIn("group_pk", body)  # Should exist since create_group=True
-        self.assertIn("token", body)
-
-        # Verify field types
-        self.assertEqual(body["username"], "valid-service-account")
-        self.assertIsInstance(body["user_pk"], int)
-        self.assertIsInstance(body["user_uid"], str)
-        self.assertIsInstance(body["token"], str)
-        self.assertIsInstance(body["group_pk"], str)
-
-    def test_service_account_validation_without_group(self):
-        """Test service account creation without creating a group"""
-        self.client.force_login(self.admin)
-
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "no-group-service-account",
-                "create_group": False,
-            },
-        )
-        self.assertEqual(response.status_code, 200)
-
-        body = loads(response.content)
-        self.assertIn("username", body)
-        self.assertIn("user_uid", body)
-        self.assertIn("user_pk", body)
-        self.assertIn("token", body)
-        # Should NOT have group_pk when create_group=False
-        self.assertNotIn("group_pk", body)
-
-    def test_service_account_validation_duplicate_username(self):
-        """Test service account creation with duplicate username"""
-        self.client.force_login(self.admin)
-
-        # Create first service account
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "duplicate-test",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 200)
-
-        # Attempt to create second with same username
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "duplicate-test",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"name": ["This field must be unique."]},
-        )
-
-    def test_service_account_validation_invalid_create_group(self):
-        """Test service account creation with invalid create_group field"""
-        self.client.force_login(self.admin)
-
-        # Test with string instead of boolean
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "test-sa",
-                "create_group": "invalid",
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"create_group": ["Must be a valid boolean."]},
-        )
-
-        # Test with number instead of boolean
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "test-sa",
-                "create_group": 123,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"create_group": ["Must be a valid boolean."]},
-        )
-
-    def test_service_account_validation_invalid_expiring(self):
-        """Test service account creation with invalid expiring field"""
-        self.client.force_login(self.admin)
-
-        # Test with string instead of boolean
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "test-sa",
-                "expiring": "invalid",
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"expiring": ["Must be a valid boolean."]},
-        )
-
-    def test_service_account_validation_invalid_expires(self):
-        """Test service account creation with invalid expires field"""
-        self.client.force_login(self.admin)
-
-        # Test with invalid datetime string
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "test-sa",
-                "expires": "invalid-datetime",
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {
-                "expires": [
-                    "Datetime has wrong format. Use one of these formats instead: "
-                    "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
-                ]
-            },
-        )
-
-        # Test with invalid format
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "test-sa",
-                "expires": "2024-13-45",  # Invalid month/day
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {
-                "expires": [
-                    "Datetime has wrong format. Use one of these formats instead: "
-                    "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
-                ]
-            },
-        )
-
-    def test_service_account_validation_multiple_errors(self):
-        """Test service account creation with multiple validation errors"""
-        self.client.force_login(self.admin)
-
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "",  # Empty username
-                "create_group": "invalid",  # Invalid boolean
-                "expiring": 123,  # Invalid boolean
-                "expires": "not-a-date",  # Invalid datetime
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {
-                "name": ["This field may not be blank."],
-                "create_group": ["Must be a valid boolean."],
-                "expiring": ["Must be a valid boolean."],
-                "expires": [
-                    "Datetime has wrong format. Use one of these formats instead: "
-                    "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
-                ],
-            },
-        )
-
-    def test_service_account_validation_user_friendly_duplicate_error(self):
-        """Test that duplicate username returns user-friendly error, not database error"""
-        self.client.force_login(self.admin)
-
-        # Create first service account
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "duplicate-username-test",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 200)
-
-        # Attempt to create second with same username
-        response = self.client.post(
-            reverse("authentik_api:user-service-account"),
-            data={
-                "name": "duplicate-username-test",
-                "create_group": True,
-            },
-        )
-        self.assertEqual(response.status_code, 400)
-        self.assertJSONEqual(
-            response.content,
-            {"name": ["This field must be unique."]},
-        )
@@ -30,7 +30,6 @@ from authentik.flows.views.interface import FlowInterfaceView
 from authentik.root.asgi_middleware import AuthMiddlewareStack
 from authentik.root.messages.consumer import MessageConsumer
 from authentik.root.middleware import ChannelsLoggingMiddleware
-from authentik.tenants.channels import TenantsAwareMiddleware

 urlpatterns = [
     path(
@@ -98,9 +97,7 @@ api_urlpatterns = [
 websocket_urlpatterns = [
     path(
         "ws/client/",
-        ChannelsLoggingMiddleware(
-            TenantsAwareMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi()))
-        ),
+        ChannelsLoggingMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi())),
     ),
 ]
@@ -7,12 +7,13 @@ from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives.serialization import load_pem_private_key
 from cryptography.x509.base import load_pem_x509_certificate
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq.actor import actor
 from structlog.stdlib import get_logger

 from authentik.crypto.models import CertificateKeyPair
 from authentik.lib.config import CONFIG
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 LOGGER = get_logger()
@@ -37,7 +38,7 @@ def ensure_certificate_valid(body: str):

 @actor(description=_("Discover, import and update certificates from the filesystem."))
 def certificate_discovery():
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     certs = {}
     private_keys = {}
     discovered = 0
@@ -1,5 +1,6 @@
 from django.db.models.aggregates import Count
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq.actor import actor
 from structlog import get_logger

@@ -7,7 +8,7 @@ from authentik.enterprise.policies.unique_password.models import (
     UniquePasswordPolicy,
     UserPasswordHistory,
 )
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 LOGGER = get_logger()

@@ -18,7 +19,7 @@ LOGGER = get_logger()
     )
 )
 def check_and_purge_password_history():
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()

     if not UniquePasswordPolicy.objects.exists():
         UserPasswordHistory.objects.all().delete()
@@ -38,7 +39,7 @@ def trim_password_histories():
     UniquePasswordPolicy policies.
     """

-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()

     # No policy, we'll let the cleanup above do its thing
     if not UniquePasswordPolicy.objects.exists():
@@ -4,6 +4,7 @@ from uuid import UUID

 from django.http import HttpRequest
 from django.utils.timezone import now
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq.actor import actor
 from requests.exceptions import RequestException
 from structlog.stdlib import get_logger
@@ -19,7 +20,7 @@ from authentik.enterprise.providers.ssf.models import (
 from authentik.lib.utils.http import get_http_session
 from authentik.lib.utils.time import timedelta_from_string
 from authentik.policies.engine import PolicyEngine
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 session = get_http_session()
 LOGGER = get_logger()
@@ -73,7 +74,7 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:

 @actor(description=_("Send an SSF event."))
 def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()

     stream = Stream.objects.filter(pk=stream_uuid).first()
     if not stream:
@@ -4,6 +4,7 @@ from uuid import UUID

 from django.db.models.query_utils import Q
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
 from dramatiq.actor import actor
 from guardian.shortcuts import get_anonymous_user
 from structlog.stdlib import get_logger
@@ -18,7 +19,7 @@ from authentik.events.models import (
 from authentik.lib.utils.db import chunked_queryset
 from authentik.policies.engine import PolicyEngine
 from authentik.policies.models import PolicyBinding, PolicyEngineMode
-from authentik.tasks.middleware import CurrentTask
 from authentik.tasks.models import Task

 LOGGER = get_logger()

@@ -37,7 +38,7 @@ def event_trigger_dispatch(event_uuid: UUID):
 )
 def event_trigger_handler(event_uuid: UUID, trigger_name: str):
     """Check if policies attached to NotificationRule match event"""
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()

     event: Event = Event.objects.filter(event_uuid=event_uuid).first()
     if not event:
@@ -130,7 +131,7 @@ def gdpr_cleanup(user_pk: int):
 @actor(description=_("Cleanup seen notifications and notifications whose event expired."))
 def notification_cleanup():
     """Cleanup seen notifications and notifications whose event expired."""
-    self = CurrentTask.get_task()
+    self: Task = CurrentTask.get_task()
     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
     amount = notifications.count()
     notifications.delete()
@@ -54,7 +54,6 @@ class Challenge(PassiveSerializer):

     flow_info = ContextualFlowInfo(required=False)
     component = CharField(default="")
-    xid = CharField(required=False)

     response_errors = DictField(
         child=ErrorDetailSerializer(many=True), allow_empty=True, required=False
@@ -143,12 +143,10 @@ class FlowPlan:
         request: HttpRequest,
         flow: Flow,
         allowed_silent_types: list["StageView"] | None = None,
-        **get_params,
     ) -> HttpResponse:
         """Redirect to the flow executor for this flow plan"""
         from authentik.flows.views.executor import (
             SESSION_KEY_PLAN,
-            FlowContainer,
             FlowExecutorView,
         )
@@ -159,7 +157,6 @@ class FlowPlan:
         # No unskippable stages found, so we can directly return the response of the last stage
         final_stage: type[StageView] = self.bindings[-1].stage.view
         temp_exec = FlowExecutorView(flow=flow, request=request, plan=self)
-        temp_exec.container = FlowContainer(request)
         temp_exec.current_stage = self.bindings[-1].stage
         temp_exec.current_stage_view = final_stage
         temp_exec.setup(request, flow.slug)
@@ -177,9 +174,6 @@ class FlowPlan:
         ):
             get_qs["inspector"] = "available"

-        for key, value in get_params:
-            get_qs[key] = value
-
         return redirect_with_qs(
             "authentik_core:if-flow",
             get_qs,
@@ -192,7 +192,6 @@ class ChallengeStageView(StageView):
             )
             flow_info.is_valid()
             challenge.initial_data["flow_info"] = flow_info.data
-            challenge.initial_data["xid"] = self.executor.container.exec_id
         if isinstance(challenge, WithUserInfoChallenge):
             # If there's a pending user, update the `username` field
             # this field is only used by password managers.
@@ -29,7 +29,7 @@ window.authentik.flow = {
 {% block body %}
 <ak-skip-to-content></ak-skip-to-content>
 <ak-message-container></ak-message-container>
-<ak-flow-executor flowSlug="{{ flow.slug }}" xid="{{ xid }}">
+<ak-flow-executor flowSlug="{{ flow.slug }}">
     <ak-loading></ak-loading>
 </ak-flow-executor>
 {% endblock %}
@@ -1,7 +1,6 @@
 """authentik multi-stage authentication engine"""

 from copy import deepcopy
-from uuid import uuid4

 from django.conf import settings
 from django.contrib.auth.mixins import LoginRequiredMixin
@@ -64,7 +63,6 @@ from authentik.policies.engine import PolicyEngine
 LOGGER = get_logger()
 # Argument used to redirect user after login
 NEXT_ARG_NAME = "next"
-SESSION_KEY_PLAN_CONTAINER = "authentik/flows/plan_container/%s"
 SESSION_KEY_PLAN = "authentik/flows/plan"
 SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre"
 SESSION_KEY_GET = "authentik/flows/get"
@@ -72,7 +70,6 @@ SESSION_KEY_POST = "authentik/flows/post"
 SESSION_KEY_HISTORY = "authentik/flows/history"
 QS_KEY_TOKEN = "flow_token"  # nosec
 QS_QUERY = "query"
-QS_EXEC_ID = "xid"


 def challenge_types():
@@ -99,88 +96,6 @@ class InvalidStageError(SentryIgnoredException):
     """Error raised when a challenge from a stage is not valid"""


-class FlowContainer:
-    """Allow for multiple concurrent flow executions in the same session"""
-
-    def __init__(self, request: HttpRequest, exec_id: str | None = None) -> None:
-        self.request = request
-        self.exec_id = exec_id
-
-    @staticmethod
-    def new(request: HttpRequest):
-        exec_id = str(uuid4())
-        request.session[SESSION_KEY_PLAN_CONTAINER % exec_id] = {}
-        return FlowContainer(request, exec_id)
-
-    def exists(self) -> bool:
-        """Check if flow exists in container/session"""
-        return SESSION_KEY_PLAN in self.session
-
-    def save(self):
-        self.request.session.modified = True
-
-    @property
-    def session(self):
-        # Backwards compatibility: store session plan/etc directly in session
-        if not self.exec_id:
-            return self.request.session
-        self.request.session.setdefault(SESSION_KEY_PLAN_CONTAINER % self.exec_id, {})
-        return self.request.session.get(SESSION_KEY_PLAN_CONTAINER % self.exec_id, {})
-
-    @property
-    def plan(self) -> FlowPlan:
-        return self.session.get(SESSION_KEY_PLAN)
-
-    def to_redirect(
-        self,
-        request: HttpRequest,
-        flow: Flow,
-        allowed_silent_types: list[StageView] | None = None,
-        **get_params,
-    ) -> HttpResponse:
-        get_params[QS_EXEC_ID] = self.exec_id
-        return self.plan.to_redirect(
-            request, flow, allowed_silent_types=allowed_silent_types, **get_params
-        )
-
-    @plan.setter
-    def plan(self, value: FlowPlan):
-        self.session[SESSION_KEY_PLAN] = value
-        self.request.session.modified = True
-        self.save()
-
-    @property
-    def application_pre(self):
-        return self.session.get(SESSION_KEY_APPLICATION_PRE)
-
-    @property
-    def get(self) -> QueryDict:
-        return self.session.get(SESSION_KEY_GET)
-
-    @get.setter
-    def get(self, value: QueryDict):
-        self.session[SESSION_KEY_GET] = value
-        self.save()
-
-    @property
-    def post(self) -> QueryDict:
-        return self.session.get(SESSION_KEY_POST)
-
-    @post.setter
-    def post(self, value: QueryDict):
-        self.session[SESSION_KEY_POST] = value
-        self.save()
-
-    @property
-    def history(self) -> list[FlowPlan]:
-        return self.session.get(SESSION_KEY_HISTORY)
-
-    @history.setter
-    def history(self, value: list[FlowPlan]):
-        self.session[SESSION_KEY_HISTORY] = value
-        self.save()
-
-
 @method_decorator(xframe_options_sameorigin, name="dispatch")
 class FlowExecutorView(APIView):
     """Flow executor, passing requests to Stage Views"""
@@ -188,9 +103,8 @@ class FlowExecutorView(APIView):
     permission_classes = [AllowAny]

     flow: Flow = None
     plan: FlowPlan | None = None
-    container: FlowContainer

     current_binding: FlowStageBinding | None = None
     current_stage: Stage
     current_stage_view: View
@@ -246,12 +160,10 @@ class FlowExecutorView(APIView):
         if QS_KEY_TOKEN in get_params:
             plan = self._check_flow_token(get_params[QS_KEY_TOKEN])
             if plan:
-                container = FlowContainer.new(request)
-                container.plan = plan
+                self.request.session[SESSION_KEY_PLAN] = plan
         # Early check if there's an active Plan for the current session
-        self.container = FlowContainer(request, request.GET.get(QS_EXEC_ID))
-        if self.container.exists():
-            self.plan: FlowPlan = self.container.plan
+        if SESSION_KEY_PLAN in self.request.session:
+            self.plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
             if self.plan.flow_pk != self.flow.pk.hex:
                 self._logger.warning(
                     "f(exec): Found existing plan for other flow, deleting plan",
@@ -264,14 +176,13 @@ class FlowExecutorView(APIView):
                 self._logger.debug("f(exec): Continuing existing plan")

         # Initial flow request, check if we have an upstream query string passed in
-        self.container.get = get_params
+        request.session[SESSION_KEY_GET] = get_params
         # Don't check session again as we've either already loaded the plan or we need to plan
         if not self.plan:
-            self.container.history = []
+            request.session[SESSION_KEY_HISTORY] = []
             self._logger.debug("f(exec): No active Plan found, initiating planner")
             try:
                 self.plan = self._initiate_plan()
-                self.container.plan = self.plan
             except FlowNonApplicableException as exc:
                 self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
                 return self.handle_invalid_flow(exc)
@@ -344,19 +255,12 @@ class FlowExecutorView(APIView):
         request=OpenApiTypes.NONE,
         parameters=[
             OpenApiParameter(
-                name=QS_QUERY,
+                name="query",
                 location=OpenApiParameter.QUERY,
                 required=True,
                 description="Querystring as received",
                 type=OpenApiTypes.STR,
             ),
-            OpenApiParameter(
-                name=QS_EXEC_ID,
-                location=OpenApiParameter.QUERY,
-                required=False,
-                description="Flow execution ID",
-                type=OpenApiTypes.STR,
-            ),
         ],
         operation_id="flows_executor_get",
     )
@@ -383,8 +287,8 @@ class FlowExecutorView(APIView):
                 span.set_data("authentik Stage", self.current_stage_view)
                 span.set_data("authentik Flow", self.flow.slug)
                 stage_response = self.current_stage_view.dispatch(request)
-                return to_stage_response(request, stage_response, self.container.exec_id)
-        except Exception as exc:
+                return to_stage_response(request, stage_response)
+        except Exception as exc:  # noqa
             return self.handle_exception(exc)

     @extend_schema(
@@ -402,19 +306,12 @@ class FlowExecutorView(APIView):
         ),
         parameters=[
             OpenApiParameter(
-                name=QS_QUERY,
+                name="query",
                 location=OpenApiParameter.QUERY,
                 required=True,
                 description="Querystring as received",
                 type=OpenApiTypes.STR,
             ),
-            OpenApiParameter(
-                name=QS_EXEC_ID,
-                location=OpenApiParameter.QUERY,
-                required=True,
-                description="Flow execution ID",
-                type=OpenApiTypes.STR,
-            ),
         ],
         operation_id="flows_executor_solve",
     )
@@ -441,15 +338,14 @@ class FlowExecutorView(APIView):
                 span.set_data("authentik Stage", self.current_stage_view)
                 span.set_data("authentik Flow", self.flow.slug)
                 stage_response = self.current_stage_view.dispatch(request)
-                return to_stage_response(request, stage_response, self.container.exec_id)
+                return to_stage_response(request, stage_response)
         except Exception as exc:  # noqa
             return self.handle_exception(exc)

     def _initiate_plan(self) -> FlowPlan:
         planner = FlowPlanner(self.flow)
         plan = planner.plan(self.request)
-        container = FlowContainer.new(self.request)
-        container.plan = plan
+        self.request.session[SESSION_KEY_PLAN] = plan
         try:
             # Call the has_stages getter to check that
             # there are no issues with the class we might've gotten
@@ -473,7 +369,7 @@ class FlowExecutorView(APIView):
         except FlowNonApplicableException as exc:
             self._logger.warning("f(exec): Flow restart not applicable to current user", exc=exc)
             return self.handle_invalid_flow(exc)
-        self.container.plan = plan
+        self.request.session[SESSION_KEY_PLAN] = plan
         kwargs = self.kwargs
         kwargs.update({"flow_slug": self.flow.slug})
         return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
@@ -495,13 +391,9 @@ class FlowExecutorView(APIView):
             )
             self.cancel()
         if next_param and not is_url_absolute(next_param):
-            return to_stage_response(
-                self.request, redirect_with_qs(next_param), self.container.exec_id
-            )
+            return to_stage_response(self.request, redirect_with_qs(next_param))
         return to_stage_response(
-            self.request,
-            self.stage_invalid(error_message=_("Invalid next URL")),
-            self.container.exec_id,
+            self.request, self.stage_invalid(error_message=_("Invalid next URL"))
         )

     def stage_ok(self) -> HttpResponse:
@@ -515,7 +407,7 @@ class FlowExecutorView(APIView):
             self.current_stage_view.cleanup()
         self.request.session.get(SESSION_KEY_HISTORY, []).append(deepcopy(self.plan))
         self.plan.pop()
-        self.container.plan = self.plan
+        self.request.session[SESSION_KEY_PLAN] = self.plan
         if self.plan.bindings:
             self._logger.debug(
                 "f(exec): Continuing with next stage",
@@ -558,7 +450,6 @@ class FlowExecutorView(APIView):

     def cancel(self):
         """Cancel current flow execution"""
-        # TODO: Clean up container
         keys_to_delete = [
             SESSION_KEY_APPLICATION_PRE,
             SESSION_KEY_PLAN,
@@ -581,8 +472,8 @@ class CancelView(View):

     def get(self, request: HttpRequest) -> HttpResponse:
         """View which canels the currently active plan"""
-        if FlowContainer(request, request.GET.get(QS_EXEC_ID)).exists():
-            del request.session[SESSION_KEY_PLAN_CONTAINER % request.GET.get(QS_EXEC_ID)]
+        if SESSION_KEY_PLAN in request.session:
+            del request.session[SESSION_KEY_PLAN]
         LOGGER.debug("Canceled current plan")
         return redirect("authentik_flows:default-invalidation")
@@ -630,12 +521,19 @@ class ToDefaultFlow(View):

     def dispatch(self, request: HttpRequest) -> HttpResponse:
         flow = self.get_flow()
-        get_qs = request.GET.copy()
-        get_qs[QS_EXEC_ID] = str(uuid4())
-        return redirect_with_qs("authentik_core:if-flow", get_qs, flow_slug=flow.slug)
+        # If user already has a pending plan, clear it so we don't have to later.
+        if SESSION_KEY_PLAN in self.request.session:
+            plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
+            if plan.flow_pk != flow.pk.hex:
+                LOGGER.warning(
+                    "f(def): Found existing plan for other flow, deleting plan",
+                    flow_slug=flow.slug,
+                )
+                del self.request.session[SESSION_KEY_PLAN]
+        return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)


-def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> HttpResponse:
+def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:
     """Convert normal HttpResponse into JSON Response"""
     if (
         isinstance(source, HttpResponseRedirect)
@@ -654,7 +552,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> HttpResponse:
             RedirectChallenge(
                 {
                     "to": str(redirect_url),
-                    "xid": xid,
                 }
             )
         )
@@ -663,7 +560,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> HttpResponse:
             ShellChallenge(
                 {
                     "body": source.render().content.decode("utf-8"),
-                    "xid": xid,
                 }
             )
         )
@@ -673,7 +569,6 @@ def to_stage_response(request: HttpRequest, source: HttpResponse, xid: str) -> HttpResponse:
             ShellChallenge(
                 {
                     "body": source.content.decode("utf-8"),
-                    "xid": xid,
                 }
             )
         )
@@ -705,6 +600,4 @@ class ConfigureFlowInitView(LoginRequiredMixin, View):
         except FlowNonApplicableException:
             LOGGER.warning("Flow not applicable to user")
             raise Http404 from None
-        container = FlowContainer.new(request)
-        container.plan = plan
-        return container.to_redirect(request, stage.configure_flow)
+        return plan.to_redirect(request, stage.configure_flow)
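The bulk of this file's diff removes FlowContainer, the old side's mechanism for concurrent flow executions: each execution id owns its own dict stored under SESSION_KEY_PLAN_CONTAINER % exec_id, selected on later requests via the ?xid= query parameter, with exec_id=None falling back to the flat legacy session keys. A usage sketch, as the removed code wires it together:

```python
# Old side only (the class removed above): one session, several
# independent flow plans keyed by execution id.
container = FlowContainer.new(request)  # allocates a uuid4 exec_id + empty dict
container.plan = plan                   # stored under the namespaced session key
# to_redirect() forwards to FlowPlan.to_redirect() with xid=<exec_id> appended,
# so the next request can re-select this execution via FlowContainer(request, xid).
return container.to_redirect(request, flow)
```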
@@ -7,7 +7,6 @@ from ua_parser.user_agent_parser import Parse

 from authentik.core.views.interface import InterfaceView
 from authentik.flows.models import Flow
-from authentik.flows.views.executor import QS_EXEC_ID


 class FlowInterfaceView(InterfaceView):
@@ -18,7 +17,6 @@ class FlowInterfaceView(InterfaceView):
         kwargs["flow"] = flow
         kwargs["flow_background_url"] = flow.background_url(self.request)
         kwargs["inspector"] = "inspector" in self.request.GET
-        kwargs["xid"] = self.request.GET.get(QS_EXEC_ID)
         return super().get_context_data(**kwargs)

     def compat_needs_sfe(self) -> bool:
@@ -3,10 +3,9 @@

 import re
 import socket
 from ipaddress import ip_address, ip_network
-from smtplib import SMTPException
 from textwrap import indent
 from types import CodeType
-from typing import TYPE_CHECKING, Any
+from typing import Any

 from cachetools import TLRUCache, cached
 from django.core.exceptions import FieldError
@@ -30,10 +29,6 @@ from authentik.policies.types import PolicyRequest, PolicyResult
 from authentik.providers.oauth2.id_token import IDToken
 from authentik.providers.oauth2.models import AccessToken, OAuth2Provider
 from authentik.stages.authenticator import devices_for_user
-from authentik.stages.email.utils import TemplateEmailMessage
-
-if TYPE_CHECKING:
-    from authentik.stages.email.models import EmailStage

 LOGGER = get_logger()
@@ -62,12 +57,11 @@ class BaseEvaluator:
         self._globals = {
             "ak_call_policy": self.expr_func_call_policy,
             "ak_create_event": self.expr_event_create,
-            "ak_create_jwt": self.expr_create_jwt,
             "ak_is_group_member": BaseEvaluator.expr_is_group_member,
             "ak_logger": get_logger(self._filename).bind(),
-            "ak_send_email": self.expr_send_email,
             "ak_user_by": BaseEvaluator.expr_user_by,
             "ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
+            "ak_create_jwt": self.expr_create_jwt,
             "ip_address": ip_address,
             "ip_network": ip_network,
             "list_flatten": BaseEvaluator.expr_flatten,
@@ -222,81 +216,6 @@ class BaseEvaluator:
         access_token.save()
         return access_token.token

-    def expr_send_email(
-        self,
-        address: str | list[str],
-        subject: str,
-        body: str | None = None,
-        stage: "EmailStage | None" = None,
-        template: str | None = None,
-        context: dict | None = None,
-    ) -> bool:
-        """Send an email using authentik's email system
-
-        Args:
-            address: Email address(es) to send to. Can be:
-                - Single email: "user@example.com"
-                - List of emails: ["user1@example.com", "user2@example.com"]
-            subject: Email subject
-            body: Email body (plain text/HTML). Mutually exclusive with template.
-            stage: EmailStage instance to use for settings. If None, uses global settings.
-            template: Template name to render. Mutually exclusive with body.
-            context: Additional context variables for template rendering.
-
-        Returns:
-            bool: True if email was queued successfully, False otherwise
-        """
-        # Deferred imports to avoid circular import issues
-        from authentik.stages.email.tasks import send_mails
-
-        if body and template:
-            raise ValueError("body and template parameters are mutually exclusive")
-
-        if not body and not template:
-            raise ValueError("Either body or template parameter must be provided")
-
-        # Normalize address parameter to list of (name, email) tuples
-        if isinstance(address, str):
-            # Single email address
-            to_addresses = [("", address)]
-        elif isinstance(address, list):
-            if not address:
-                raise ValueError("Address list cannot be empty")
-            # List of email strings
-            to_addresses = [("", email) for email in address]
-        else:
-            raise ValueError("Address must be a string or list of strings")
-
-        try:
-            if template is not None:
-                # Use all available context from the evaluator for template rendering
-                template_context = self._context.copy()
-                # Add any custom context passed to the function
-                if context:
-                    template_context.update(context)
-
-                # Use template rendering
-                message = TemplateEmailMessage(
-                    subject=subject,
-                    to=to_addresses,
-                    template_name=template,
-                    template_context=template_context,
-                )
-            else:
-                # Use plain body
-                message = TemplateEmailMessage(
-                    subject=subject,
-                    to=to_addresses,
-                    body=body,
-                )
-
-            send_mails(stage, message)
-            return True
-
-        except (SMTPException, ConnectionError, ValidationError, ValueError) as exc:
-            LOGGER.warning("Failed to send email", exc=exc, addresses=to_addresses, subject=subject)
-            return False
-
     def wrap_expression(self, expression: str) -> str:
         """Wrap expression in a function, call it, and save the result as `result`"""
         handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
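The removed ak_send_email helper was exposed to expression policies through self._globals. A hedged usage sketch from inside an expression, per the removed docstring (the template name is illustrative, not a real authentik template path):

```python
# Plain-body mail; returns True if the message was handed to send_mails().
ak_send_email("user@example.com", "Test Subject", body="Test Body")

# Template mail; body= and template= are mutually exclusive by design.
ak_send_email(
    ["a@example.com", "b@example.com"],
    "Digest",
    template="email/generic.html",  # illustrative template name
    context={"key": "value"},
)
```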
@@ -112,7 +112,6 @@ def get_logger_config():
         "hpack": "WARNING",
         "httpx": "WARNING",
         "azure": "WARNING",
-        "channels_postgres": "WARNING",
     }
     for handler_name, level in handler_level_map.items():
         base_config["loggers"][handler_name] = {
@@ -3,15 +3,19 @@

 from asyncio.exceptions import CancelledError
 from typing import Any

+from channels_redis.core import ChannelFull
 from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
 from django.db import DatabaseError, InternalError, OperationalError, ProgrammingError
 from django.http.response import Http404
+from django_redis.exceptions import ConnectionInterrupted
 from docker.errors import DockerException
 from dramatiq.errors import Retry
 from h11 import LocalProtocolError
 from ldap3.core.exceptions import LDAPException
 from psycopg.errors import Error
+from redis.exceptions import ConnectionError as RedisConnectionError
+from redis.exceptions import RedisError, ResponseError
 from rest_framework.exceptions import APIException
 from sentry_sdk import HttpTransport, get_current_scope
 from sentry_sdk import init as sentry_sdk_init
@@ -19,6 +23,7 @@ from sentry_sdk.api import set_tag
 from sentry_sdk.integrations.argv import ArgvIntegration
 from sentry_sdk.integrations.django import DjangoIntegration
 from sentry_sdk.integrations.dramatiq import DramatiqIntegration
+from sentry_sdk.integrations.redis import RedisIntegration
 from sentry_sdk.integrations.socket import SocketIntegration
 from sentry_sdk.integrations.stdlib import StdlibIntegration
 from sentry_sdk.integrations.threading import ThreadingIntegration
@@ -54,7 +59,13 @@ ignored_classes = (
     ProgrammingError,
     SuspiciousOperation,
     ValidationError,
+    # Redis errors
+    RedisConnectionError,
+    ConnectionInterrupted,
+    RedisError,
+    ResponseError,
     # websocket errors
+    ChannelFull,
     WebSocketException,
     LocalProtocolError,
     # rest_framework error
@@ -101,6 +112,7 @@ def sentry_init(**sentry_init_kwargs):
         ArgvIntegration(),
         DjangoIntegration(transaction_style="function_name", cache_spans=True),
         DramatiqIntegration(),
+        RedisIntegration(),
         SocketIntegration(),
         StdlibIntegration(),
         ThreadingIntegration(propagate_hub=True),
@@ -147,7 +159,9 @@ def before_send(event: dict, hint: dict) -> dict | None:
     if event["logger"] in [
         "asyncio",
         "multiprocessing",
+        "django_redis",
         "django.security.DisallowedHost",
+        "django_redis.cache",
         "paramiko.transport",
     ]:
         return None
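The Redis pieces come back in three places here: imports, ignored_classes, and the integration list; before_send then drops whole logger namespaces. A reduced sketch of that hook's contract (returning None suppresses the event entirely):

```python
def before_send(event: dict, hint: dict) -> dict | None:
    # Dropping by logger name silences noisy subsystems without changing their
    # log level; the diff adds the django_redis loggers to this list.
    if event.get("logger") in {"asyncio", "multiprocessing", "django_redis"}:
        return None
    return event
```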
@@ -2,6 +2,7 @@ from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import Actor
from dramatiq.composition import group
from dramatiq.errors import Retry

@@ -21,7 +22,6 @@ from authentik.lib.sync.outgoing.exceptions import (
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.errors import exception_to_dict
from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
@@ -61,7 +61,7 @@ class SyncTasks:
        provider_pk: int,
        sync_objects: Actor[[str, int, int, bool], None],
    ):
        task = CurrentTask.get_task()
        task: Task = CurrentTask.get_task()
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
            provider_pk=provider_pk,

@@ -118,7 +118,7 @@ class SyncTasks:
        override_dry_run=False,
        **filter,
    ):
        task = CurrentTask.get_task()
        task: Task = CurrentTask.get_task()
        _object_type: type[Model] = path_to_class(object_type)
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),

@@ -173,7 +173,7 @@ class SyncTasks:
        except TransientSyncException as exc:
            self.logger.warning("failed to sync object", exc=exc, user=obj)
            task.warning(
                f"Failed to sync {str(obj)} due to transient error: {str(exc)}",
                f"Failed to sync {str(obj)} due to " f"transient error: {str(exc)}",
                obj=sanitize_item(obj),
                exception=exception_to_dict(exc),
            )

@@ -207,7 +207,7 @@ class SyncTasks:
        provider_pk: int,
        raw_op: str,
    ):
        task = CurrentTask.get_task()
        task: Task = CurrentTask.get_task()
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
        )

@@ -281,7 +281,7 @@ class SyncTasks:
        action: str,
        pk_set: list[int],
    ):
        task = CurrentTask.get_task()
        task: Task = CurrentTask.get_task()
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
        )
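Reviewer note: this import swap (authentik.tasks.middleware vs. django_dramatiq_postgres.middleware) and the dropped `: Task` annotation recur in every task module below (outposts, kerberos, LDAP, OAuth, Plex). The call pattern itself is identical on both sides; a minimal sketch of an actor reading its own task record (actor name hypothetical):

    from django_dramatiq_postgres.middleware import CurrentTask
    from dramatiq.actor import actor

    @actor(description="Example task")
    def example_task():
        # The CurrentTask middleware stashes the Task row for the message
        # currently being processed; get_task() retrieves it inside the actor.
        task = CurrentTask.get_task()
        task.info("started")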
@@ -1,7 +1,5 @@
"""Test Evaluator base functions"""

from unittest.mock import patch

from django.test import RequestFactory, TestCase
from django.urls import reverse
from jwt import decode

@@ -79,163 +77,3 @@ class TestEvaluator(TestCase):
            jwt, provider.client_secret, algorithms=["HS256"], audience=provider.client_id
        )
        self.assertEqual(decoded["preferred_username"], user.username)
    @patch("authentik.stages.email.tasks.send_mails")
    def test_expr_send_email_with_body(self, mock_send_mails):
        """Test ak_send_email with body parameter"""
        user = create_test_user()
        evaluator = BaseEvaluator(generate_id())
        evaluator._context = {"user": user}

        # Test sending email with body
        result = evaluator.evaluate(
            "return ak_send_email('test@example.com', 'Test Subject', body='Test Body')"
        )

        self.assertTrue(result)
        mock_send_mails.assert_called_once()

        # Verify the call arguments - send_mails is called with (stage, message)
        args, kwargs = mock_send_mails.call_args
        stage, message = args

        # Check that global settings are used (stage is None)
        self.assertIsNone(stage)

        # Check message properties
        self.assertEqual(message.subject, "Test Subject")
        self.assertEqual(message.to, ["test@example.com"])
        self.assertEqual(message.body, "Test Body")

    @patch("authentik.stages.email.tasks.send_mails")
    def test_expr_send_email_with_template(self, mock_send_mails):
        """Test ak_send_email with template parameter"""
        user = create_test_user()
        evaluator = BaseEvaluator(generate_id())
        evaluator._context = {"user": user}

        # Test sending email with template
        result = evaluator.evaluate(
            "return ak_send_email('test@example.com', 'Test Subject', "
            "template='email/password_reset.html')"
        )

        self.assertTrue(result)
        mock_send_mails.assert_called_once()

    def test_expr_send_email_validation_errors(self):
        """Test ak_send_email validation errors"""
        evaluator = BaseEvaluator(generate_id())

        # Test error when both body and template are provided
        with self.assertRaises(ValueError) as cm:
            evaluator.evaluate(
                "return ak_send_email('test@example.com', 'Test', "
                "body='Body', template='template.html')"
            )
        self.assertIn("mutually exclusive", str(cm.exception))

        # Test error when neither body nor template are provided
        with self.assertRaises(ValueError) as cm:
            evaluator.evaluate("return ak_send_email('test@example.com', 'Test')")
        self.assertIn("Either body or template parameter must be provided", str(cm.exception))

    @patch("authentik.stages.email.tasks.send_mails")
    def test_expr_send_email_with_custom_stage(self, mock_send_mails):
        """Test ak_send_email with custom EmailStage"""
        from authentik.stages.email.models import EmailStage

        user = create_test_user()
        custom_stage = EmailStage(
            name="custom-stage", use_global_settings=False, from_address="custom@example.com"
        )

        evaluator = BaseEvaluator(generate_id())
        evaluator._context = {"user": user, "custom_stage": custom_stage}

        # Test sending email with custom stage
        result = evaluator.evaluate(
            "return ak_send_email('test@example.com', 'Test Subject', "
            "body='Test Body', stage=custom_stage)"
        )

        self.assertTrue(result)
        mock_send_mails.assert_called_once()

        # Verify the custom stage was used
        args, kwargs = mock_send_mails.call_args
        stage, message = args

        self.assertEqual(stage, custom_stage)
        self.assertFalse(stage.use_global_settings)

    @patch("authentik.stages.email.tasks.send_mails")
    def test_expr_send_email_with_context(self, mock_send_mails):
        """Test ak_send_email with custom context parameter"""
        user = create_test_user()
        evaluator = BaseEvaluator(generate_id())
        evaluator._context = {"user": user, "request_id": "123"}

        # Test sending email with template and custom context
        result = evaluator.evaluate(
            "return ak_send_email('test@example.com', 'Test Subject', "
            "template='email/password_reset.html', "
            "context={'url': 'http://localhost', 'expires': '2026-01-01'})"
        )

        self.assertTrue(result)
        mock_send_mails.assert_called_once()

        # Verify the call arguments - send_mails is called with (stage, message)
        args, kwargs = mock_send_mails.call_args
        stage, message = args

        # Check that global settings are used (stage is None)
        self.assertIsNone(stage)

        self.assertEqual(message.subject, "Test Subject")
        self.assertEqual(message.to, ["test@example.com"])
        self.assertIn("2026-01-01", message.body)
        self.assertIn("http://localhost", message.body)

    @patch("authentik.stages.email.tasks.send_mails")
    def test_expr_send_email_multiple_addresses(self, mock_send_mails):
        """Test ak_send_email with multiple email addresses"""
        user = create_test_user()
        evaluator = BaseEvaluator(generate_id())
        evaluator._context = {"user": user}

        # Test sending email to multiple addresses
        result = evaluator.evaluate(
            "return ak_send_email(['user1@example.com', 'user2@example.com'], "
            "'Test Subject', body='Test Body')"
        )

        self.assertTrue(result)
        mock_send_mails.assert_called_once()

        # Verify the call arguments - send_mails is called with (stage, message)
        args, kwargs = mock_send_mails.call_args
        stage, message = args

        # Check that global settings are used (stage is None)
        self.assertIsNone(stage)

        # Check message properties - should have multiple recipients
        self.assertEqual(message.subject, "Test Subject")
        self.assertEqual(message.to, ["user1@example.com", "user2@example.com"])
        self.assertEqual(message.body, "Test Body")

    def test_expr_send_email_multiple_addresses_validation(self):
        """Test ak_send_email validation with multiple addresses"""
        evaluator = BaseEvaluator(generate_id())

        # Test error when empty list is provided
        with self.assertRaises(ValueError) as cm:
            evaluator.evaluate("return ak_send_email([], 'Test', body='Body')")
        self.assertIn("Address list cannot be empty", str(cm.exception))

        # Test error when invalid type is provided
        with self.assertRaises(ValueError) as cm:
            evaluator.evaluate("return ak_send_email(123, 'Test', body='Body')")
        self.assertIn("Address must be a string or list of strings", str(cm.exception))
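Reviewer note: taken together, the removed tests pin down the helper's contract: address may be a non-empty string or list of strings, body and template are mutually exclusive and one is required, context feeds the template, and stage=None falls back to global mail settings. Signature as inferred from the assertions (not the literal implementation):

    def ak_send_email(address, subject, body=None, template=None, context=None, stage=None) -> bool:
        if body and template:
            raise ValueError("body and template are mutually exclusive")
        if not (body or template):
            raise ValueError("Either body or template parameter must be provided")
        ...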
@@ -3,9 +3,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from hashlib import sha256
from typing import Any
from uuid import UUID

from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection

@@ -20,15 +18,8 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState


def build_outpost_group(outpost_pk: str | UUID) -> str:
    return sha256(f"{connection.schema_name}/group_outpost_{str(outpost_pk)}".encode()).hexdigest()


def build_outpost_group_instance(outpost_pk: str | UUID, instance: str) -> str:
    return sha256(
        f"{connection.schema_name}/group_outpost_{str(outpost_pk)}_{instance}".encode()
    ).hexdigest()
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"


class WebsocketMessageInstruction(IntEnum):
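Reviewer note: the two sides namespace channel-layer group names differently. As ordered here, the removed helpers fold the tenant's schema_name into a fixed-length SHA-256 digest, while the added constants are plain %()s templates with no tenant component. Rendering sketch for both (outpost is illustrative):

    # template form: readable, tenant separation left to the channel layer
    group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
    # hashed form: sha256("<schema>/group_outpost_<pk>") -> opaque, per-tenant name
    group = build_outpost_group(outpost.pk)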
@@ -73,24 +64,26 @@ class OutpostConsumer(JsonWebsocketConsumer):
    def connect(self):
        uuid = self.scope["url_route"]["kwargs"]["pk"]
        user = self.scope["user"]
        self.outpost: Outpost | None = (
        outpost = (
            get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
        )
        if self.outpost is None:
        if not outpost:
            raise DenyConnection()
        self.logger = self.logger.bind(outpost=self.outpost)
        self.logger = self.logger.bind(outpost=outpost)
        try:
            self.accept()
        except RuntimeError as exc:
            self.logger.warning("runtime error during accept", exc=exc)
            raise DenyConnection() from None
        self.outpost = outpost
        query = QueryDict(self.scope["query_string"].decode())
        self.instance_uid = query.get("instance_uuid", self.channel_name)
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group(self.outpost.pk), self.channel_name
            OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
        )
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
            OUTPOST_GROUP_INSTANCE
            % {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
            self.channel_name,
        )
        GAUGE_OUTPOSTS_CONNECTED.labels(

@@ -103,11 +96,12 @@ class OutpostConsumer(JsonWebsocketConsumer):
    def disconnect(self, code):
        if self.outpost:
            async_to_sync(self.channel_layer.group_discard)(
                build_outpost_group(self.outpost.pk), self.channel_name
                OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
            )
        if self.instance_uid:
            async_to_sync(self.channel_layer.group_discard)(
                build_outpost_group_instance(self.outpost.pk, self.instance_uid),
                OUTPOST_GROUP_INSTANCE
                % {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
                self.channel_name,
            )
        if self.outpost and self.instance_uid:
@@ -12,6 +12,7 @@ from channels.layers import get_channel_layer
from django.core.cache import cache
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from docker.constants import DEFAULT_UNIX_SOCKET
from dramatiq.actor import actor
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME

@@ -20,7 +21,7 @@ from structlog.stdlib import get_logger
from yaml import safe_load

from authentik.lib.config import CONFIG
from authentik.outposts.consumer import build_outpost_group
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient
from authentik.outposts.controllers.kubernetes import KubernetesClient

@@ -40,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"

@@ -107,7 +108,7 @@ def outpost_service_connection_monitor(connection_pk: Any):
@actor(description=_("Create/update/monitor/delete the deployment of an Outpost."))
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
    """Create/update/monitor/delete the deployment of an Outpost"""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    self.set_uid(outpost_pk)
    logs = []
    if from_cache:

@@ -141,7 +142,7 @@ def outpost_token_ensurer():
    """
    Periodically ensure that all Outposts have valid Service Accounts and Tokens
    """
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    all_outposts = Outpost.objects.all()
    for outpost in all_outposts:
        _ = outpost.token

@@ -160,7 +161,7 @@ def outpost_send_update(pk: Any):
    _ = outpost.token
    outpost.build_user_permissions(outpost.user)
    layer = get_channel_layer()
    group = build_outpost_group(outpost.pk)
    group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
    LOGGER.debug("sending update", channel=group, outpost=outpost)
    async_to_sync(layer.group_send)(group, {"type": "event.update"})

@@ -168,7 +169,7 @@ def outpost_send_update(pk: Any):
@actor(description=_("Checks the local environment and creates Service connections."))
def outpost_connection_discovery():
    """Checks the local environment and creates Service connections."""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    if not CONFIG.get_bool("outposts.discover"):
        self.info("Outpost integration discovery is disabled")
        return

@@ -212,7 +213,7 @@ def outpost_session_end(session_id: str):
    hashed_session_id = hash_session_key(session_id)
    for outpost in Outpost.objects.all():
        LOGGER.info("Sending session end signal to outpost", outpost=outpost)
        group = build_outpost_group(outpost.pk)
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
        async_to_sync(layer.group_send)(
            group,
            {
@@ -45,7 +45,7 @@ class TestOutpostWS(TransactionTestCase):
        communicator = WebsocketCommunicator(
            URLRouter(websocket.websocket_urlpatterns),
            f"/ws/outpost/{self.outpost.pk}/",
            [(b"authorization", f"Bearer {self.token}".encode())],
            {b"authorization": f"Bearer {self.token}".encode()},
        )
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

@@ -56,7 +56,7 @@ class TestOutpostWS(TransactionTestCase):
        communicator = WebsocketCommunicator(
            URLRouter(websocket.websocket_urlpatterns),
            f"/ws/outpost/{self.outpost.pk}/",
            [(b"authorization", f"Bearer {self.token}".encode())],
            {b"authorization": f"Bearer {self.token}".encode()},
        )
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

@@ -83,7 +83,7 @@ class TestOutpostWS(TransactionTestCase):
        communicator = WebsocketCommunicator(
            URLRouter(websocket.websocket_urlpatterns),
            f"/ws/outpost/{self.outpost.pk}/",
            [(b"authorization", f"Bearer {self.token}".encode())],
            {b"authorization": f"Bearer {self.token}".encode()},
        )
        connected, _ = await communicator.connect()
        self.assertTrue(connected)
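Reviewer note: the only delta in these tests is the shape of the headers argument. In the ASGI connection scope, headers is defined as an iterable of (name, value) byte pairs, which the list-of-tuples form matches directly; the dict form depends on how the communicator and middleware normalize the scope. Illustrative shape only:

    headers = [(b"authorization", b"Bearer abc123")]  # ASGI-spec shape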
@@ -11,14 +11,11 @@ from authentik.outposts.api.service_connections import (
from authentik.outposts.channels import TokenOutpostMiddleware
from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware
from authentik.tenants.channels import TenantsAwareMiddleware

websocket_urlpatterns = [
    path(
        "ws/outpost/<uuid:pk>/",
        ChannelsLoggingMiddleware(
            TenantsAwareMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi()))
        ),
        ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
    ),
]
@@ -23,8 +23,6 @@ SCOPE_OPENID_PROFILE = "profile"
SCOPE_OPENID_EMAIL = "email"
SCOPE_OFFLINE_ACCESS = "offline_access"

UI_LOCALES = "ui_locales"

# https://www.iana.org/assignments/oauth-parameters/auth-parameters.xhtml#pkce-code-challenge-method
PKCE_METHOD_PLAIN = "plain"
PKCE_METHOD_S256 = "S256"
@@ -1,13 +1,14 @@
"""OAuth2 Provider Tasks"""

from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger

from authentik.lib.utils.http import get_http_session
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.oauth2.utils import create_logout_token
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()

@@ -30,7 +31,7 @@ def send_backchannel_logout_request(
    Returns:
        bool: True if the request was sent successfully, False otherwise
    """
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    LOGGER.debug("Sending back-channel logout request", provider_pk=provider_pk, sub=sub)

    provider = OAuth2Provider.objects.filter(pk=provider_pk).first()
@@ -1,7 +1,6 @@
"""Test authorize view"""

from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, urlparse

from django.test import RequestFactory
from django.urls import reverse

@@ -671,55 +670,3 @@ class TestAuthorize(OAuthTestCase):
        )
        parsed = OAuthAuthorizationParams.from_request(request)
        self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)

    def test_ui_locales(self):
        """Test OIDC ui_locales authorization"""
        flow = create_test_flow()
        provider = OAuth2Provider.objects.create(
            name=generate_id(),
            client_id="test",
            authorization_flow=flow,
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
            access_code_validity="seconds=100",
        )
        Application.objects.create(name="app", slug="app", provider=provider)
        state = generate_id()
        self.client.logout()
        response = self.client.get(
            reverse("authentik_providers_oauth2:authorize"),
            data={
                "response_type": "code",
                "client_id": "test",
                "state": state,
                "redirect_uri": "foo://localhost",
                "ui_locales": "invalid fr",
            },
        )
        parsed = parse_qs(urlparse(response.url).query)
        self.assertEqual(parsed["locale"], ["fr"])

    def test_ui_locales_invalid(self):
        """Test OIDC ui_locales authorization with no valid locale"""
        flow = create_test_flow()
        provider = OAuth2Provider.objects.create(
            name=generate_id(),
            client_id="test",
            authorization_flow=flow,
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
            access_code_validity="seconds=100",
        )
        Application.objects.create(name="app", slug="app", provider=provider)
        state = generate_id()
        self.client.logout()
        response = self.client.get(
            reverse("authentik_providers_oauth2:authorize"),
            data={
                "response_type": "code",
                "client_id": "test",
                "state": state,
                "redirect_uri": "foo://localhost",
                "ui_locales": "invalid",
            },
        )
        parsed = parse_qs(urlparse(response.url).query)
        self.assertNotIn("locale", parsed)
@@ -5,14 +5,13 @@ from datetime import timedelta
from json import dumps
from re import error as RegexError
from re import fullmatch
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlsplit, urlunparse, urlunsplit
from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
from uuid import uuid4

from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.http import HttpRequest, HttpResponse
from django.http.response import Http404, HttpResponseBadRequest
from django.shortcuts import get_object_or_404
from django.utils import timezone, translation
from django.utils import timezone
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger

@@ -42,7 +41,6 @@ from authentik.providers.oauth2.constants import (
    SCOPE_OFFLINE_ACCESS,
    SCOPE_OPENID,
    TOKEN_TYPE,
    UI_LOCALES,
)
from authentik.providers.oauth2.errors import (
    AuthorizeError,

@@ -389,45 +387,6 @@ class AuthorizationFlowInitView(BufferedPolicyAccessView):
        request.context["oauth_response_type"] = self.params.response_type
        return request

    def dispatch_with_language(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
        """Activate language from OIDC specific ui_locales parameter, picking the earliest one
        available"""
        selected_language = None
        if UI_LOCALES in self.request.GET:
            languages = str(self.request.GET[UI_LOCALES]).split(" ")
            for language in languages:
                if translation.check_for_language(language):
                    selected_language = translation.get_supported_language_variant(language)
                    LOGGER.debug(
                        "Activating language from oidc ui_locales", locale=selected_language
                    )
                    break
        translation.activate(selected_language)
        response = super().dispatch(request, *args, **kwargs)
        if selected_language:
            response.set_cookie(
                settings.LANGUAGE_COOKIE_NAME,
                selected_language,
                max_age=settings.LANGUAGE_COOKIE_AGE,
                path=settings.LANGUAGE_COOKIE_PATH,
                domain=settings.LANGUAGE_COOKIE_DOMAIN,
                secure=settings.LANGUAGE_COOKIE_SECURE,
                httponly=settings.LANGUAGE_COOKIE_HTTPONLY,
                samesite=settings.LANGUAGE_COOKIE_SAMESITE,
            )
            if isinstance(response, HttpResponseRedirect):
                parsed_url = urlparse(response.url)
                args = parse_qs(parsed_url.query)
                args["locale"] = selected_language
                response["Location"] = urlunparse(
                    parsed_url._replace(query=urlencode(args, quote_via=quote, doseq=True))
                )
        return response

    def dispatch(self, request: HttpRequest, *args, **kwargs):
        # Activate language before parsing params (error messages should be localised)
        return self.dispatch_with_language(request, *args, **kwargs)

    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
        """Start FlowPlanner, return to flow executor shell"""
        # Require a login event to be set, otherwise make the user re-login
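Reviewer note: ui_locales is a space-separated, preference-ordered list per OIDC Core, and the removed handler selects the first entry Django can actually serve. Condensed sketch mirroring the removed loop:

    from django.utils import translation

    def pick_ui_locale(raw: str) -> str | None:
        # "invalid fr" -> "fr": the first recognizable entry wins.
        for language in raw.split(" "):
            if translation.check_for_language(language):
                return translation.get_supported_language_variant(language)
        return None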
@@ -5,7 +5,7 @@ from channels.layers import get_channel_layer
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor

from authentik.outposts.consumer import build_outpost_group
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key

@@ -15,7 +15,7 @@ def proxy_on_logout(session_id: str):
    layer = get_channel_layer()
    hashed_session_id = hash_session_key(session_id)
    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
        group = build_outpost_group(outpost.pk)
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
        async_to_sync(layer.group_send)(
            group,
            {
@@ -1,46 +1,28 @@
"""RAC Client consumer"""

from hashlib import sha256

from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull, DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from django.db import connection
from django.http.request import QueryDict
from structlog.stdlib import BoundLogger, get_logger

from authentik.outposts.consumer import build_outpost_group_instance
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.providers.rac.models import ConnectionToken, RACProvider


def build_rac_client_group() -> str:
    """
    Global broadcast group, which messages are sent to when the outpost connects back
    to authentik for a specific connection
    The `RACClientConsumer` consumer adds itself to this group on connection,
    and removes itself once it has been assigned a specific outpost channel
    """
    return sha256(f"{connection.schema_name}/group_rac_client".encode()).hexdigest()


def build_rac_client_group_session(session_key: str) -> str:
    """
    A group for all connections in a given authentik session ID
    A disconnect message is sent to this group when the session expires/is deleted
    """
    return sha256(f"{connection.schema_name}/group_rac_client_{session_key}".encode()).hexdigest()


def build_rac_client_group_token(token: str) -> str:
    """
    A group for all connections with a specific token, which in almost all cases
    is just one connection, however this is used to disconnect the connection
    when the token is deleted
    """
    return sha256(f"{connection.schema_name}/group_rac_token_{token}".encode()).hexdigest()

# Global broadcast group, which messages are sent to when the outpost connects back
# to authentik for a specific connection
# The `RACClientConsumer` consumer adds itself to this group on connection,
# and removes itself once it has been assigned a specific outpost channel
RAC_CLIENT_GROUP = "group_rac_client"
# A group for all connections in a given authentik session ID
# A disconnect message is sent to this group when the session expires/is deleted
RAC_CLIENT_GROUP_SESSION = "group_rac_client_%(session)s"
# A group for all connections with a specific token, which in almost all cases
# is just one connection, however this is used to disconnect the connection
# when the token is deleted
RAC_CLIENT_GROUP_TOKEN = "group_rac_token_%(token)s"  # nosec

# Step 1: Client connects to this websocket endpoint
# Step 2: We prepare all the connection args for Guac
@@ -63,23 +45,22 @@ class RACClientConsumer(AsyncWebsocketConsumer):
    async def connect(self):
        self.logger = get_logger()
        await self.accept("guacamole")
        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
        await self.channel_layer.group_add(
            build_rac_client_group_session(self.scope["session"].session_key),
            RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
            self.channel_name,
        )
        await self.init_outpost_connection()

    async def disconnect(self, code):
        self.logger.debug("Disconnecting")
        if self.dest_channel_id:
            # Tell the outpost we're disconnecting
            await self.channel_layer.send(
                self.dest_channel_id,
                {
                    "type": "event.disconnect",
                },
            )
        # Tell the outpost we're disconnecting
        await self.channel_layer.send(
            self.dest_channel_id,
            {
                "type": "event.disconnect",
            },
        )

    @database_sync_to_async
    def init_outpost_connection(self):

@@ -128,8 +109,10 @@ class RACClientConsumer(AsyncWebsocketConsumer):
        if len(states) < 1:
            continue
        self.logger.debug("Sending out connection broadcast")
        group = build_outpost_group_instance(outpost.pk, states[0].uid)
        async_to_sync(self.channel_layer.group_send)(group, msg)
        async_to_sync(self.channel_layer.group_send)(
            OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
            msg,
        )
        if self.provider and self.provider.delete_token_on_disconnect:
            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
            self.token.delete()

@@ -174,7 +157,7 @@ class RACClientConsumer(AsyncWebsocketConsumer):
        self.dest_channel_id = outpost_channel
        # Since we have a specific outpost channel now, we can remove
        # ourselves from the global broadcast group
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)

    async def event_send(self, event: dict):
        """Handler called by outpost websocket that sends data to this specific
@@ -3,7 +3,7 @@
from channels.exceptions import ChannelFull
from channels.generic.websocket import AsyncWebsocketConsumer

from authentik.providers.rac.consumer_client import build_rac_client_group
from authentik.providers.rac.consumer_client import RAC_CLIENT_GROUP


class RACOutpostConsumer(AsyncWebsocketConsumer):

@@ -15,7 +15,7 @@ class RACOutpostConsumer(AsyncWebsocketConsumer):
        self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
        await self.accept()
        await self.channel_layer.group_send(
            build_rac_client_group(),
            RAC_CLIENT_GROUP,
            {
                "type": "event.outpost.connected",
                "outpost_channel": self.channel_name,
@@ -9,8 +9,8 @@ from django.dispatch import receiver
from authentik.core.models import AuthenticatedSession
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import (
    build_rac_client_group_session,
    build_rac_client_group_token,
    RAC_CLIENT_GROUP_SESSION,
    RAC_CLIENT_GROUP_TOKEN,
)
from authentik.providers.rac.models import ConnectionToken, Endpoint

@@ -19,7 +19,10 @@ from authentik.providers.rac.models import ConnectionToken, Endpoint
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
    layer = get_channel_layer()
    async_to_sync(layer.group_send)(
        build_rac_client_group_session(instance.session.session_key),
        RAC_CLIENT_GROUP_SESSION
        % {
            "session": instance.session.session_key,
        },
        {"type": "event.disconnect", "reason": "session_logout"},
    )

@@ -29,7 +32,10 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **
    """Disconnect session when connection token is deleted"""
    layer = get_channel_layer()
    async_to_sync(layer.group_send)(
        build_rac_client_group_token(instance.token),
        RAC_CLIENT_GROUP_TOKEN
        % {
            "token": instance.token,
        },
        {"type": "event.disconnect", "reason": "token_delete"},
    )
@@ -12,7 +12,6 @@ from authentik.providers.rac.consumer_outpost import RACOutpostConsumer
from authentik.providers.rac.views import RACInterface, RACStartView
from authentik.root.asgi_middleware import AuthMiddlewareStack
from authentik.root.middleware import ChannelsLoggingMiddleware
from authentik.tenants.channels import TenantsAwareMiddleware

urlpatterns = [
    path(

@@ -30,15 +29,11 @@ urlpatterns = [
websocket_urlpatterns = [
    path(
        "ws/rac/<str:token>/",
        ChannelsLoggingMiddleware(
            TenantsAwareMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi()))
        ),
        ChannelsLoggingMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi())),
    ),
    path(
        "ws/outpost_rac/<str:channel>/",
        ChannelsLoggingMiddleware(
            TenantsAwareMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi()))
        ),
        ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
    ),
]
@@ -1,35 +0,0 @@
from typing import Any

from channels_postgres.core import PostgresChannelLayer as BasePostgresChannelLayer
from channels_postgres.db import DatabaseLayer as BaseDatabaseLayer
from django.conf import settings
from psycopg_pool import AsyncConnectionPool

from authentik.root.db.base import DatabaseWrapper


class DatabaseLayer(BaseDatabaseLayer):
    async def get_db_pool(self, db_params: dict[str, Any]) -> AsyncConnectionPool:
        db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
        db_params = db_wrapper.get_connection_params()
        db_params.pop("cursor_factory")
        db_params.pop("context")
        return await super().get_db_pool(db_params)


class PostgresChannelLayer(BasePostgresChannelLayer):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.django_db = DatabaseLayer(self.django_db.psycopg_options, self.db_params)

    @property
    def db_params(self):
        db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
        db_params = db_wrapper.get_connection_params()
        db_params.pop("cursor_factory")
        db_params.pop("context")
        return db_params

    @db_params.setter
    def db_params(self, value):
        pass

@@ -292,7 +292,7 @@ class ChannelsLoggingMiddleware:
        except DenyConnection:
            return await send({"type": "websocket.close"})
        except Exception as exc:
            if settings.DEBUG or settings.TEST:
            if settings.DEBUG:
                raise exc
            LOGGER.warning("Exception in ASGI application", exc=exc)
            return await send({"type": "websocket.close"})
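Reviewer note: the removed PostgresChannelLayer leans on a small Python pattern worth naming: db_params is re-declared as a computed read-only property, and the no-op setter exists purely so the base class's own `self.db_params = ...` assignment is absorbed instead of raising AttributeError. Generic illustration of the pattern (names hypothetical):

    class Base:
        def __init__(self):
            self.params = {"host": "db"}  # plain attribute assignment in the base

    class Derived(Base):
        @property
        def params(self):
            return {"host": "computed"}  # always derived, never stored

        @params.setter
        def params(self, value):
            pass  # silently absorb the base-class assignment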
@@ -11,6 +11,8 @@ from django.dispatch import Signal
from django.http import HttpRequest, HttpResponse
from django.views import View
from django_prometheus.exports import ExportToDjangoView
from django_redis import get_redis_connection
from redis.exceptions import RedisError

monitoring_set = Signal()

@@ -42,17 +44,19 @@ class LiveView(View):


class ReadyView(View):
    """View for readiness probe, always returns HTTP 200, unless sql is down"""

    def check_db(self):
        for db_conn in connections.all():
            # Force connection reload
            db_conn.connect()
            _ = db_conn.cursor()
    """View for readiness probe, always returns HTTP 200, unless sql or redis is down"""

    def dispatch(self, request: HttpRequest) -> HttpResponse:
        try:
            self.check_db()
            for db_conn in connections.all():
                # Force connection reload
                db_conn.connect()
                _ = db_conn.cursor()
        except OperationalError:  # pragma: no cover
            return HttpResponse(status=503)
        try:
            redis_conn = get_redis_connection()
            redis_conn.ping()
        except RedisError:  # pragma: no cover
            return HttpResponse(status=503)
        return HttpResponse(status=200)
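Reviewer note: the added side extends readiness to Redis with the same 503-on-failure contract as the SQL check. The Redis half in isolation, as the new dispatch() performs it:

    from django_redis import get_redis_connection
    from redis.exceptions import RedisError

    def redis_ready() -> bool:
        # PING the default django-redis connection; any RedisError means
        # the probe should report 503 (not ready).
        try:
            get_redis_connection().ping()
        except RedisError:
            return False
        return True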
@@ -10,7 +10,7 @@ from sentry_sdk import set_tag
from xmlsec import enable_debug_trace

from authentik import authentik_version
from authentik.lib.config import CONFIG, django_db_config
from authentik.lib.config import CONFIG, django_db_config, redis_url
from authentik.lib.logging import get_logger_config, structlog_configure
from authentik.lib.sentry import sentry_init
from authentik.lib.utils.reflection import get_env

@@ -64,7 +64,6 @@ SHARED_APPS = [
    "pgactivity",
    "pglock",
    "channels",
    "channels_postgres",
    "django_dramatiq_postgres",
    "authentik.tasks",
]

@@ -73,7 +72,6 @@ TENANT_APPS = [
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "pgtrigger",
    "django_postgres_cache",
    "authentik.admin",
    "authentik.api",
    "authentik.core",

@@ -105,7 +103,6 @@ TENANT_APPS = [
    "authentik.sources.plex",
    "authentik.sources.saml",
    "authentik.sources.scim",
    "authentik.sources.telegram",
    "authentik.stages.authenticator",
    "authentik.stages.authenticator_duo",
    "authentik.stages.authenticator_email",

@@ -228,11 +225,20 @@ REST_FRAMEWORK = {

CACHES = {
    "default": {
        "BACKEND": "django_postgres_cache.backend.DatabaseCache",
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db")),
        "TIMEOUT": CONFIG.get_int("cache.timeout", 300),
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
        },
        "KEY_PREFIX": "authentik_cache",
        "KEY_FUNCTION": "django_tenants.cache.make_key",
        "REVERSE_KEY_FUNCTION": "django_tenants.cache.reverse_key",
    }
}
DJANGO_REDIS_SCAN_ITERSIZE = 1000
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
SESSION_ENGINE = "authentik.core.sessions"
# Configured via custom SessionMiddleware
# SESSION_COOKIE_SAMESITE = "None"

@@ -290,6 +296,16 @@ TEMPLATES = [

ASGI_APPLICATION = "authentik.root.asgi.application"

CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
        "CONFIG": {
            "hosts": [CONFIG.get("channel.url") or redis_url(CONFIG.get("redis.db"))],
            "prefix": "authentik_channels_",
        },
    },
}


# Database
# https://docs.djangoproject.com/en/2.1/ref/settings/#databases

@@ -302,16 +318,6 @@ DATABASE_ROUTERS = (
    "django_tenants.routers.TenantSyncRouter",
)

CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "authentik.root.channels.PostgresChannelLayer",
        "CONFIG": {
            **DATABASES["default"],
            "TIME_ZONE": None,
        },
    },
}

# Email
# These values should never actually be used, emails are only sent from email stages, which
# loads the config directly from CONFIG

@@ -393,6 +399,8 @@ DRAMATIQ = {
    ).total_seconds(),
    "middlewares": (
        ("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}),
        # TODO: fixme
        # ("dramatiq.middleware.prometheus.Prometheus", {}),
        ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
        ("dramatiq.middleware.age_limit.AgeLimit", {}),
        (

@@ -415,7 +423,7 @@ DRAMATIQ = {
            },
        ),
        ("dramatiq.results.middleware.Results", {"store_results": True}),
        ("authentik.tasks.middleware.CurrentTask", {}),
        ("django_dramatiq_postgres.middleware.CurrentTask", {}),
        ("authentik.tasks.middleware.TenantMiddleware", {}),
        ("authentik.tasks.middleware.RelObjMiddleware", {}),
        ("authentik.tasks.middleware.MessagesMiddleware", {}),
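Reviewer note: this settings.py delta is the clearest summary of the whole compare — the base side keeps cache and the channel layer on PostgreSQL (django_postgres_cache, channels_postgres, the custom PostgresChannelLayer), while this side wires both back to Redis with django-tenants key functions. The Redis location resolution used by both CACHES and CHANNEL_LAYERS above:

    # Explicit URL wins; otherwise build a redis:// URL from the configured db index.
    location = CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db"))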
@@ -62,11 +62,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover
        """Configure test environment settings"""
        settings.TEST = True
        settings.DRAMATIQ["test"] = True
        settings.CHANNEL_LAYERS["default"]["CONFIG"] = {
            **settings.DATABASES["default"],
            **settings.DATABASES["default"]["TEST"],
            "TIME_ZONE": None,
        }

        # Test-specific configuration
        test_config = {
@@ -2,6 +2,7 @@

from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger

@@ -9,7 +10,7 @@ from authentik.lib.config import CONFIG
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"

@@ -32,7 +33,7 @@ def kerberos_connectivity_check(pk: str):
    description=_("Sync Kerberos source."),
)
def kerberos_sync(pk: str):
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first()
    if not source:
        return
@@ -4,6 +4,7 @@ from uuid import uuid4

from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from dramatiq.composition import group
from dramatiq.message import Message

@@ -20,7 +21,6 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()

@@ -53,7 +53,7 @@ def ldap_connectivity_check(pk: str | None = None):
)
def ldap_sync(source_pk: str):
    """Sync a single source"""
    task = CurrentTask.get_task()
    task: Task = CurrentTask.get_task()
    source: LDAPSource = LDAPSource.objects.filter(pk=source_pk, enabled=True).first()
    if not source:
        return

@@ -127,7 +127,7 @@ def ldap_sync_paginator(
)
def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
    """Synchronization of an LDAP Source"""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
    if not source:
        # Because the source couldn't be found, we don't have a UID
@@ -3,13 +3,14 @@
from json import dumps

from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests import RequestException
from structlog.stdlib import get_logger

from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()

@@ -20,7 +21,7 @@ LOGGER = get_logger()
    )
)
def update_well_known_jwks():
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    session = get_http_session()
    for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
        try:
@@ -1,6 +1,7 @@
"""Plex tasks"""

from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests import RequestException

@@ -8,13 +9,13 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string
from authentik.sources.plex.models import PlexSource
from authentik.sources.plex.plex import PlexAuth
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task


@actor(description=_("Check the validity of a Plex source."))
def check_plex_token(source_pk: str):
    """Check the validity of a Plex source."""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    sources = PlexSource.objects.filter(pk=source_pk)
    if not sources.exists():
        return
@@ -1,31 +0,0 @@
"""Telegram source property mappings API"""

from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.telegram.models import TelegramSourcePropertyMapping


class TelegramSourcePropertyMappingSerializer(PropertyMappingSerializer):
    """TelegramSourcePropertyMapping Serializer"""

    class Meta(PropertyMappingSerializer.Meta):
        model = TelegramSourcePropertyMapping


class TelegramSourcePropertyMappingFilter(PropertyMappingFilterSet):
    """Filter for TelegramSourcePropertyMapping"""

    class Meta(PropertyMappingFilterSet.Meta):
        model = TelegramSourcePropertyMapping


class TelegramSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
    """TelegramSourcePropertyMapping Viewset"""

    queryset = TelegramSourcePropertyMapping.objects.all()
    serializer_class = TelegramSourcePropertyMappingSerializer
    filterset_class = TelegramSourcePropertyMappingFilter
    search_fields = ["name"]
    ordering = ["name"]
@@ -1,41 +0,0 @@
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.telegram.models import TelegramSource


class TelegramSourceSerializer(SourceSerializer):
    class Meta:
        model = TelegramSource
        fields = SourceSerializer.Meta.fields + [
            "bot_username",
            "bot_token",
            "request_message_access",
            "pre_authentication_flow",
        ]
        extra_kwargs = {
            "bot_token": {"write_only": True},
        }


class TelegramSourceViewSet(UsedByMixin, ModelViewSet):
    queryset = TelegramSource.objects.all()
    serializer_class = TelegramSourceSerializer
    lookup_field = "slug"

    filterset_fields = [
        "pbm_uuid",
        "name",
        "slug",
        "enabled",
        "authentication_flow",
        "enrollment_flow",
        "policy_engine_mode",
        "user_matching_mode",
        "group_matching_mode",
        "bot_username",
        "request_message_access",
    ]
    search_fields = ["name", "slug"]
    ordering = ["name"]
@@ -1,33 +0,0 @@
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.sources import (
    GroupSourceConnectionSerializer,
    GroupSourceConnectionViewSet,
    UserSourceConnectionSerializer,
    UserSourceConnectionViewSet,
)
from authentik.sources.telegram.models import (
    GroupTelegramSourceConnection,
    UserTelegramSourceConnection,
)


class UserTelegramSourceConnectionSerializer(UserSourceConnectionSerializer):
    class Meta(UserSourceConnectionSerializer.Meta):
        model = UserTelegramSourceConnection
        fields = UserSourceConnectionSerializer.Meta.fields


class UserTelegramSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet):
    queryset = UserTelegramSourceConnection.objects.all()
    serializer_class = UserTelegramSourceConnectionSerializer


class GroupTelegramSourceConnectionSerializer(GroupSourceConnectionSerializer):
    class Meta(GroupSourceConnectionSerializer.Meta):
        model = GroupTelegramSourceConnection


class GroupTelegramSourceConnectionViewSet(GroupSourceConnectionViewSet, ModelViewSet):
    queryset = GroupTelegramSourceConnection.objects.all()
    serializer_class = GroupTelegramSourceConnectionSerializer
@@ -1,9 +0,0 @@
from authentik.blueprints.apps import ManagedAppConfig


class TelegramConfig(ManagedAppConfig):
    name = "authentik.sources.telegram"
    label = "authentik_sources_telegram"
    verbose_name = "authentik Sources.Telegram"
    mountpoint = "source/telegram/"
    default = True
@@ -1,118 +0,0 @@
# Generated by Django 5.1.12 on 2025-09-24 07:14

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

    initial = True

    dependencies = [
        ("authentik_core", "0050_user_last_updated_and_more"),
        ("authentik_flows", "0028_flowtoken_revoke_on_execution"),
    ]

    operations = [
        migrations.CreateModel(
            name="GroupTelegramSourceConnection",
            fields=[
                (
                    "groupsourceconnection_ptr",
                    models.OneToOneField(
                        auto_created=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        serialize=False,
                        to="authentik_core.groupsourceconnection",
                    ),
                ),
            ],
            options={
                "verbose_name": "Group Telegram Source Connection",
                "verbose_name_plural": "Group Telegram Source Connections",
            },
            bases=("authentik_core.groupsourceconnection",),
        ),
        migrations.CreateModel(
            name="TelegramSourcePropertyMapping",
            fields=[
                (
                    "propertymapping_ptr",
                    models.OneToOneField(
                        auto_created=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        serialize=False,
                        to="authentik_core.propertymapping",
                    ),
                ),
            ],
            options={
                "verbose_name": "Telegram Source Property Mapping",
                "verbose_name_plural": "Telegram Source Property Mappings",
            },
            bases=("authentik_core.propertymapping",),
        ),
        migrations.CreateModel(
            name="UserTelegramSourceConnection",
            fields=[
                (
                    "usersourceconnection_ptr",
                    models.OneToOneField(
                        auto_created=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        serialize=False,
                        to="authentik_core.usersourceconnection",
                    ),
                ),
            ],
            options={
                "verbose_name": "User Telegram Source Connection",
                "verbose_name_plural": "User Telegram Source Connections",
            },
            bases=("authentik_core.usersourceconnection",),
        ),
        migrations.CreateModel(
            name="TelegramSource",
            fields=[
                (
                    "source_ptr",
                    models.OneToOneField(
                        auto_created=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        serialize=False,
                        to="authentik_core.source",
                    ),
                ),
                ("bot_username", models.TextField(help_text="Telegram bot username")),
                ("bot_token", models.TextField(help_text="Telegram bot token")),
                (
                    "request_message_access",
                    models.BooleanField(
                        default=False, help_text="Request access to send messages from your bot."
                    ),
                ),
                (
                    "pre_authentication_flow",
                    models.ForeignKey(
                        help_text="Flow used before authentication.",
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="telegram_source_pre_authentication",
                        to="authentik_flows.flow",
                    ),
                ),
            ],
            options={
                "verbose_name": "Telegram Source",
                "verbose_name_plural": "Telegram Sources",
            },
            bases=("authentik_core.source",),
        ),
    ]
@@ -1,156 +0,0 @@
"""Telegram source"""

from typing import Any

from django.db import models
from django.http import HttpRequest
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import BaseSerializer, Serializer

from authentik.core.models import (
    GroupSourceConnection,
    PropertyMapping,
    Source,
    UserSourceConnection,
)
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.models import Flow


class TelegramSource(Source):
    """Log in with Telegram."""

    bot_username = models.TextField(help_text=_("Telegram bot username"))
    bot_token = models.TextField(help_text=_("Telegram bot token"))

    request_message_access = models.BooleanField(
        default=False, help_text=_("Request access to send messages from your bot.")
    )

    pre_authentication_flow = models.ForeignKey(
        Flow,
        on_delete=models.CASCADE,
        help_text=_("Flow used before authentication."),
        related_name="telegram_source_pre_authentication",
    )

    @property
    def component(self) -> str:
        return "ak-source-telegram-form"

    @property
    def icon_url(self) -> str | None:
        icon = super().icon_url
        if not icon:
            icon = static("authentik/sources/telegram.svg")
        return icon

    @property
    def serializer(self) -> type[BaseSerializer]:
        from authentik.sources.telegram.api.source import TelegramSourceSerializer

        return TelegramSourceSerializer

    def ui_login_button(self, request: HttpRequest) -> UILoginButton:
        return UILoginButton(
            challenge=RedirectChallenge(
                data={
                    "to": reverse(
                        "authentik_sources_telegram:start",
                        kwargs={"source_slug": self.slug},
                    ),
                }
            ),
            name=self.name,
            icon_url=self.icon_url,
        )

    def ui_user_settings(self) -> UserSettingSerializer | None:
        return UserSettingSerializer(
            data={
                "title": self.name,
                "component": "ak-user-settings-source-telegram",
                "icon_url": self.icon_url,
            }
        )

    @property
    def property_mapping_type(self) -> "type[PropertyMapping]":
        return TelegramSourcePropertyMapping

    def get_base_user_properties(
        self, info: dict[str, Any] | None = None, **kwargs
    ) -> dict[str, Any | dict[str, Any]]:
        info = info or {}
        name = info.get("first_name", "")
        if "last_name" in info:
            name += " " + info["last_name"]
        return {
            "username": info.get("username", None),
            "email": None,
            "name": name if name else None,
        }

    def get_base_group_properties(self, group_id: str, **kwargs):
        return {
            "name": group_id,
        }

    class Meta:
        verbose_name = _("Telegram Source")
        verbose_name_plural = _("Telegram Sources")


class TelegramSourcePropertyMapping(PropertyMapping):
    """Map Telegram properties to User or Group object attributes"""

    @property
    def component(self) -> str:
        return "ak-property-mapping-source-telegram-form"

    @property
    def serializer(self) -> type[Serializer]:
        from authentik.sources.telegram.api.property_mappings import (
            TelegramSourcePropertyMappingSerializer,
        )

        return TelegramSourcePropertyMappingSerializer

    class Meta:
        verbose_name = _("Telegram Source Property Mapping")
        verbose_name_plural = _("Telegram Source Property Mappings")


class UserTelegramSourceConnection(UserSourceConnection):
    """Connect user and Telegram source"""

    @property
    def serializer(self) -> type[Serializer]:
        from authentik.sources.telegram.api.source_connection import (
            UserTelegramSourceConnectionSerializer,
        )

        return UserTelegramSourceConnectionSerializer

    class Meta:
        verbose_name = _("User Telegram Source Connection")
        verbose_name_plural = _("User Telegram Source Connections")


class GroupTelegramSourceConnection(GroupSourceConnection):
    """Group-source connection for Telegram"""

    @property
    def serializer(self) -> type[Serializer]:
        from authentik.sources.telegram.api.source_connection import (
            GroupTelegramSourceConnectionSerializer,
        )

        return GroupTelegramSourceConnectionSerializer

    class Meta:
        verbose_name = _("Group Telegram Source Connection")
        verbose_name_plural = _("Group Telegram Source Connections")
@@ -1,48 +0,0 @@
import hashlib
import hmac
from datetime import datetime, timedelta

from django.utils.translation import gettext_lazy as _
from rest_framework.fields import BooleanField, CharField, IntegerField, URLField
from rest_framework.serializers import ValidationError

from authentik.flows.challenge import Challenge, ChallengeResponse
from authentik.stages.identification.stage import LoginChallengeMixin


class TelegramLoginChallenge(LoginChallengeMixin, Challenge):
    component = CharField(default="ak-source-telegram")
    bot_username = CharField(help_text=_("Telegram bot username"))
    request_message_access = BooleanField()


class TelegramChallengeResponse(ChallengeResponse):
    component = CharField(default="ak-source-telegram")

    id = IntegerField()
    first_name = CharField(max_length=255, required=False)
    last_name = CharField(max_length=255, required=False)
    username = CharField(max_length=255, required=False)
    photo_url = URLField(required=False)
    auth_date = IntegerField(required=True)
    hash = CharField(max_length=64, required=True)

    def validate_auth_date(self, auth_date: int) -> int:
        if datetime.fromtimestamp(auth_date) < datetime.now() - timedelta(minutes=5):
            raise ValidationError(_("Authentication date is too old"))
        return auth_date

    def validate(self, attrs: dict) -> dict:
        # Check the response as defined in https://core.telegram.org/widgets/login
        attrs_to_check = attrs.copy()
        attrs_to_check.pop("component")
        attrs_to_check.pop("hash")
        check_str = "\n".join([f"{key}={value}" for key, value in sorted(attrs_to_check.items())])
        digest = hmac.new(
            hashlib.sha256(self.stage.source.bot_token.encode("utf-8")).digest(),
            check_str.encode("utf-8"),
            "sha256",
        ).hexdigest()
        if not hmac.compare_digest(digest, attrs["hash"]):
            raise ValidationError(_("Invalid hash"))
        return attrs
@@ -1,185 +0,0 @@
"""Telegram source tests"""

import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock

from django.test import TestCase
from django.urls import reverse
from rest_framework.exceptions import ValidationError

from authentik.core.tests.utils import create_test_flow
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.tests import FlowTestCase
from authentik.sources.telegram.stage import TelegramChallengeResponse
from authentik.stages.identification.models import IdentificationStage, UserFields


class MockTelegramResponseMixin:
    def _add_hash(self, response):
        to_hash = "\n".join([f"{key}={value}" for key, value in sorted(response.items())])
        response["hash"] = hmac.new(
            hashlib.sha256(self.source.bot_token.encode("utf-8")).digest(),
            to_hash.encode("utf-8"),
            "sha256",
        ).hexdigest()

    def _make_valid_response(self):
        resp = {
            "id": "123456789",
            "first_name": "Test",
            "last_name": "User",
            "username": "testuser",
            "auth_date": str(int(datetime.now().timestamp())),
        }
        self._add_hash(resp)
        return resp

    def _make_outdated_response(self):
        resp = self._make_valid_response()
        resp["auth_date"] = str(int((datetime.now() - timedelta(days=1)).timestamp()))
        self._add_hash(resp)
        return resp


class TestTelegramSource(MockTelegramResponseMixin, TestCase):
    """Telegram Source tests"""

    def setUp(self):
        from authentik.sources.telegram.models import TelegramSource

        self.source = TelegramSource.objects.create(
            name="test",
            slug="test",
            bot_username="test_bot",
            bot_token="modern_token",  # nosec
            request_message_access=True,
            pre_authentication_flow=create_test_flow(),
        )
        self.mock_stage = Mock()
        self.mock_stage.source = self.source

    def test_ui_login_button(self):
        """Test UI login button"""
        ui_login_button = self.source.ui_login_button(None)
        self.assertIsNotNone(ui_login_button)
        self.assertEqual(ui_login_button.name, "test")
        self.assertTrue(ui_login_button.challenge.is_valid(raise_exception=True))

    def test_challenge_response(self):
        """Test correct Telegram response validation"""
        cr = TelegramChallengeResponse(data=self._make_valid_response())
        cr.stage = self.mock_stage
        self.assertTrue(cr.is_valid(raise_exception=True))

    def test_outdated_challenge_response(self):
        """Test outdated Telegram response validation"""
        cr = TelegramChallengeResponse(data=self._make_outdated_response())
        cr.stage = self.mock_stage
        with self.assertRaises(ValidationError):
            cr.is_valid(raise_exception=True)

    def test_invalid_hash_challenge_response(self):
        """Test invalid hash in Telegram response validation"""
        resp = self._make_valid_response()
        resp["hash"] = "invalid_hash"
        cr = TelegramChallengeResponse(data=resp)
        cr.stage = self.mock_stage
        with self.assertRaises(ValidationError):
            cr.is_valid(raise_exception=True)

    def test_user_base_properties(self):
        """Test user base properties"""
        cr = TelegramChallengeResponse(data=self._make_valid_response())
        cr.stage = self.mock_stage
        cr.is_valid(raise_exception=True)
        properties = self.source.get_base_user_properties(info=cr.validated_data)
        self.assertEqual(
            properties,
            {
                "username": "testuser",
                "name": "Test User",
                "email": None,
            },
        )

    def test_group_base_properties(self):
        """Test group base properties"""
        for group_id in ["group 1", "group 2"]:
            properties = self.source.get_base_group_properties(group_id=group_id)
            self.assertEqual(properties, {"name": group_id})


class TestTelegramViews(MockTelegramResponseMixin, FlowTestCase):
    """Test Telegram source views"""

    def setUp(self):
        super().setUp()
        from authentik.sources.telegram.models import TelegramSource

        self.pre_auth_flow = create_test_flow()

        self.source = TelegramSource.objects.create(
            name="test",
            slug="test",
            bot_username="test_bot",
            bot_token="modern_token",  # nosec
            request_message_access=True,
            enrollment_flow=create_test_flow(),
            pre_authentication_flow=self.pre_auth_flow,
        )

        self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
        self.stage = IdentificationStage.objects.create(
            name="identification",
            user_fields=[UserFields.E_MAIL],
            pretend_user_exists=False,
        )
        self.stage.sources.set([self.source])
        self.stage.save()
        FlowStageBinding.objects.create(
            target=self.flow,
            stage=self.stage,
            order=0,
        )

    def _make_initial_request(self):
        return self.client.get(
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
        )

    def _make_start_request(self):
        return self.client.get(
            reverse("authentik_sources_telegram:start", kwargs={"source_slug": self.source.slug}),
            follow=True,
        )

    def test_start_view(self):
        """Test TelegramStartView"""
        self.assertEqual(self._make_initial_request().status_code, 200)

        response = self._make_start_request()
        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.redirect_chain[0][0],
            reverse("authentik_core:if-flow", kwargs={"flow_slug": self.pre_auth_flow.slug}),
        )

    def test_challenge_view(self):
        """Test TelegramLoginView"""
        self._make_initial_request()
        self._make_start_request()
        url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.pre_auth_flow.slug})
        get_response = self.client.get(url)
        self.assertEqual(get_response.status_code, 200)
        form_data = self._make_valid_response()
        form_data["component"] = "ak-source-telegram"
        response = self.client.post(url, form_data)
        self.assertEqual(response.status_code, 200)
        self.assertStageRedirects(
            response,
            reverse(
                "authentik_core:if-flow", kwargs={"flow_slug": self.source.enrollment_flow.slug}
            ),
        )
@@ -1,23 +0,0 @@
"""Telegram source API views"""

from django.urls import path

from authentik.sources.telegram.api.property_mappings import TelegramSourcePropertyMappingViewSet
from authentik.sources.telegram.api.source import TelegramSourceViewSet
from authentik.sources.telegram.api.source_connection import (
    GroupTelegramSourceConnectionViewSet,
    UserTelegramSourceConnectionViewSet,
)
from authentik.sources.telegram.views import TelegramLoginView, TelegramStartView

urlpatterns = [
    path("<slug:source_slug>/start/", TelegramStartView.as_view(), name="start"),
    path("<slug:source_slug>/", TelegramLoginView.as_view(), name="login"),
]

api_urlpatterns = [
    ("propertymappings/source/telegram", TelegramSourcePropertyMappingViewSet),
    ("sources/user_connections/telegram", UserTelegramSourceConnectionViewSet),
    ("sources/group_connections/telegram", GroupTelegramSourceConnectionViewSet),
    ("sources/telegram", TelegramSourceViewSet),
]
@@ -1,98 +0,0 @@
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404
from django.views import View

from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.flows.challenge import Challenge
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import (
    PLAN_CONTEXT_REDIRECT,
    PLAN_CONTEXT_SOURCE,
    PLAN_CONTEXT_SSO,
    FlowPlanner,
)
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_GET
from authentik.sources.telegram.models import (
    GroupTelegramSourceConnection,
    TelegramSource,
    UserTelegramSourceConnection,
)
from authentik.sources.telegram.stage import TelegramChallengeResponse, TelegramLoginChallenge


class TelegramStartView(View):
    def handle_login_flow(
        self, source: TelegramSource, *stages_to_append, **kwargs
    ) -> HttpResponse:
        """Prepare Authentication Plan, redirect user FlowExecutor"""
        # Ensure redirect is carried through when user was trying to
        # authorize application
        final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
            NEXT_ARG_NAME, "authentik_core:if-user"
        )
        kwargs.update(
            {
                PLAN_CONTEXT_SSO: True,
                PLAN_CONTEXT_SOURCE: source,
                PLAN_CONTEXT_REDIRECT: final_redirect,
            }
        )
        # We run the Flow planner here so we can pass the Pending user in the context
        planner = FlowPlanner(source.pre_authentication_flow)
        planner.allow_empty_flows = True
        try:
            plan = planner.plan(self.request, kwargs)
        except FlowNonApplicableException:
            raise Http404 from None
        for stage in stages_to_append:
            plan.append_stage(stage)
        return plan.to_redirect(self.request, source.pre_authentication_flow)

    def get(self, request: HttpRequest, source_slug: str) -> HttpResponse:
        source = get_object_or_404(TelegramSource, slug=source_slug, enabled=True)
        telegram_login_stage = in_memory_stage(TelegramLoginView)

        return self.handle_login_flow(source, telegram_login_stage)


class TelegramSourceFlowManager(SourceFlowManager):
    """Flow manager for Telegram source"""

    user_connection_type = UserTelegramSourceConnection
    group_connection_type = GroupTelegramSourceConnection


class TelegramLoginView(ChallengeStageView):

    response_class = TelegramChallengeResponse

    def dispatch(self, request, *args, **kwargs):
        self.source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
        return super().dispatch(request, *args, **kwargs)

    def get_challenge(self, *args, **kwargs) -> Challenge:
        return TelegramLoginChallenge(
            data={
                "bot_username": self.source.bot_username,
                "request_message_access": self.source.request_message_access,
            },
        )

    def challenge_valid(self, response: TelegramChallengeResponse) -> HttpResponse:
        raw_info = response.validated_data.copy()
        raw_info.pop("component")
        raw_info.pop("hash")
        raw_info.pop("auth_date")
        source = self.source
        sfm = TelegramSourceFlowManager(
            source=source,
            request=self.request,
            identifier=raw_info["id"],
            user_info={"info": raw_info},
            policy_context={"telegram": raw_info},
        )
        return sfm.get_flow(
            raw_info=raw_info,
        )
File diff suppressed because one or more lines are too long
@@ -7,6 +7,7 @@ from pathlib import Path
from django.core.cache import cache
from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from fido2.mds3 import filter_revoked, parse_blob

@@ -14,7 +15,7 @@ from authentik.stages.authenticator_webauthn.models import (
    UNKNOWN_DEVICE_TYPE_AAGUID,
    WebAuthnDeviceType,
)
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no"
AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json"
@@ -32,7 +33,7 @@ def mds_ca() -> bytes:
@actor(description=_("Background task to import FIDO Alliance MDS blob and AAGUIDs into database."))
def webauthn_mds_import(force=False):
    """Background task to import FIDO Alliance MDS blob and AAGUIDs into database"""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    with open(MDS_BLOB_PATH, mode="rb") as _raw_blob:
        blob = parse_blob(_raw_blob.read(), mds_ca())
    to_create_update = [

@@ -7,6 +7,7 @@ from django.core.mail import EmailMultiAlternatives
from django.core.mail.utils import DNS_NAME
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from dramatiq.composition import group
from structlog.stdlib import get_logger
@@ -16,31 +17,27 @@ from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.stages.authenticator_email.models import AuthenticatorEmailStage
from authentik.stages.email.models import EmailStage
from authentik.stages.email.utils import logo_data
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task

LOGGER = get_logger()


def send_mails(
    stage: EmailStage | AuthenticatorEmailStage | None, *messages: list[EmailMultiAlternatives]
    stage: EmailStage | AuthenticatorEmailStage, *messages: list[EmailMultiAlternatives]
):
    """Wrapper to convert EmailMessage to dict and send it from worker

    Args:
        stage: Either an EmailStage or AuthenticatorEmailStage instance,
            or nothing to use global settings
        stage: Either an EmailStage or AuthenticatorEmailStage instance
        messages: List of email messages to send
    Returns:
        Dramatiq group promise for the email sending tasks
    """
    tasks = []
    # Use the class path instead of the class itself for serialization
    stage_class_path, stage_pk = None, None
    if stage:
        stage_class_path = class_to_path(stage.__class__)
        stage_pk = str(stage.pk)
    stage_class_path = class_to_path(stage.__class__)
    for message in messages:
        tasks.append(send_mail.message(message.__dict__, stage_class_path, stage_pk))
        tasks.append(send_mail.message(message.__dict__, stage_class_path, str(stage.pk)))
    return group(tasks).run()


@@ -59,7 +56,7 @@ def send_mail(
    email_stage_pk: str | None = None,
):
    """Send Email for Email Stage. Retries are scheduled automatically."""
    self = CurrentTask.get_task()
    self: Task = CurrentTask.get_task()
    message_id = make_msgid(domain=DNS_NAME)
    self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_")))
    if not stage_class_path or not email_stage_pk:

@@ -123,7 +123,7 @@ class UserLoginStageView(ChallengeStageView):
    def is_known_device(self, user: User):
        """Returns `True` if the login happened on a "known" device, by the same user."""
        client_ip = ClientIPMiddleware.get_client_ip(self.request)
        if AuthenticatedSession.objects.filter(session__last_ip=client_ip, user=user).exists():
        if AuthenticatedSession.objects.filter(session__last_ip=client_ip).exists():
            return True
        if COOKIE_NAME_KNOWN_DEVICE not in self.request.COOKIES:
            return False

@@ -1,22 +1,21 @@
import socket
from http.server import BaseHTTPRequestHandler
from time import sleep
from typing import Any, cast
from typing import Any

import pglock
from django.db import OperationalError, connections
from django.utils.timezone import now
from django_dramatiq_postgres.middleware import (
    CurrentTask as BaseCurrentTask,
)
from django_dramatiq_postgres.middleware import HTTPServer
from django_dramatiq_postgres.middleware import (
    MetricsMiddleware as BaseMetricsMiddleware,
)
from django_redis import get_redis_connection
from dramatiq.broker import Broker
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from psycopg.errors import Error
from redis.exceptions import RedisError
from structlog.stdlib import get_logger

from authentik import authentik_full_version
@@ -29,13 +28,7 @@ from authentik.tenants.utils import get_current_tenant

LOGGER = get_logger()
HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
DB_ERRORS = (OperationalError, Error)


class CurrentTask(BaseCurrentTask):
    @classmethod
    def get_task(cls) -> Task:
        return cast(Task, super().get_task())
DB_ERRORS = (OperationalError, Error, RedisError)


class TenantMiddleware(Middleware):
@@ -186,6 +179,8 @@ class _healthcheck_handler(BaseHTTPRequestHandler):
            # Force connection reload
            db_conn.connect()
            _ = db_conn.cursor()
            redis_conn = get_redis_connection()
            redis_conn.ping()
            self.send_response(200)
        except DB_ERRORS:  # pragma: no cover
            self.send_response(503)

@@ -1,48 +0,0 @@
from channels.db import database_sync_to_async
from django.db import close_old_connections, connection
from django.http.request import split_domain_port
from django_tenants.utils import (
    get_public_schema_name,
    remove_www,
)

from authentik.tenants.models import Domain, Tenant


class TenantsAwareMiddleware:
    """Set the database schema for use with django-tenants"""

    def __init__(self, inner):
        self.inner = inner

    def get_hostname_from_scope(self, scope: list[tuple[bytes, bytes]]) -> str | None:
        headers = {k.replace(b"-", b"_").upper(): v for k, v in scope.get("headers", [])}
        hostname, _ = split_domain_port(headers.get(b"HOST", b"").decode("utf-8"))
        if not hostname:
            return None
        return remove_www(hostname)

    async def get_default_tenant(self) -> Tenant:
        return await database_sync_to_async(Tenant.objects.get)(
            schema_name=get_public_schema_name()
        )

    async def get_tenant(self, hostname: str | None) -> Tenant:
        if not hostname:
            return await self.get_default_tenant()

        try:
            domain = await database_sync_to_async(Domain.objects.select_related("tenant").get)(
                domain=hostname
            )
        except Domain.DoesNotExist:
            return await self.get_default_tenant()
        return domain.tenant

    async def __call__(self, scope, receive, send):
        close_old_connections()
        hostname = self.get_hostname_from_scope(scope)
        tenant = await self.get_tenant(hostname)
        scope["tenant"] = tenant
        connection.set_tenant(tenant)
        return await self.inner(scope, receive, send)
@@ -3016,166 +3016,6 @@
                }
            }
        },
        {
            "type": "object",
            "required": [
                "model",
                "identifiers"
            ],
            "properties": {
                "model": {
                    "const": "authentik_sources_telegram.grouptelegramsourceconnection"
                },
                "id": {
                    "type": "string"
                },
                "state": {
                    "type": "string",
                    "enum": [
                        "absent",
                        "created",
                        "must_created",
                        "present"
                    ],
                    "default": "present"
                },
                "conditions": {
                    "type": "array",
                    "items": {
                        "type": "boolean"
                    }
                },
                "permissions": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection_permissions"
                },
                "attrs": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection"
                },
                "identifiers": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.grouptelegramsourceconnection"
                }
            }
        },
        {
            "type": "object",
            "required": [
                "model",
                "identifiers"
            ],
            "properties": {
                "model": {
                    "const": "authentik_sources_telegram.telegramsource"
                },
                "id": {
                    "type": "string"
                },
                "state": {
                    "type": "string",
                    "enum": [
                        "absent",
                        "created",
                        "must_created",
                        "present"
                    ],
                    "default": "present"
                },
                "conditions": {
                    "type": "array",
                    "items": {
                        "type": "boolean"
                    }
                },
                "permissions": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsource_permissions"
                },
                "attrs": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsource"
                },
                "identifiers": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsource"
                }
            }
        },
        {
            "type": "object",
            "required": [
                "model",
                "identifiers"
            ],
            "properties": {
                "model": {
                    "const": "authentik_sources_telegram.telegramsourcepropertymapping"
                },
                "id": {
                    "type": "string"
                },
                "state": {
                    "type": "string",
                    "enum": [
                        "absent",
                        "created",
                        "must_created",
                        "present"
                    ],
                    "default": "present"
                },
                "conditions": {
                    "type": "array",
                    "items": {
                        "type": "boolean"
                    }
                },
                "permissions": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping_permissions"
                },
                "attrs": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping"
                },
                "identifiers": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.telegramsourcepropertymapping"
                }
            }
        },
        {
            "type": "object",
            "required": [
                "model",
                "identifiers"
            ],
            "properties": {
                "model": {
                    "const": "authentik_sources_telegram.usertelegramsourceconnection"
                },
                "id": {
                    "type": "string"
                },
                "state": {
                    "type": "string",
                    "enum": [
                        "absent",
                        "created",
                        "must_created",
                        "present"
                    ],
                    "default": "present"
                },
                "conditions": {
                    "type": "array",
                    "items": {
                        "type": "boolean"
                    }
                },
                "permissions": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection_permissions"
                },
                "attrs": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection"
                },
                "identifiers": {
                    "$ref": "#/$defs/model_authentik_sources_telegram.usertelegramsourceconnection"
                }
            }
        },
        {
            "type": "object",
            "required": [
@@ -5431,22 +5271,6 @@
    "authentik_sources_scim.view_scimsourcegroup",
    "authentik_sources_scim.view_scimsourcepropertymapping",
    "authentik_sources_scim.view_scimsourceuser",
    "authentik_sources_telegram.add_grouptelegramsourceconnection",
    "authentik_sources_telegram.add_telegramsource",
    "authentik_sources_telegram.add_telegramsourcepropertymapping",
    "authentik_sources_telegram.add_usertelegramsourceconnection",
    "authentik_sources_telegram.change_grouptelegramsourceconnection",
    "authentik_sources_telegram.change_telegramsource",
    "authentik_sources_telegram.change_telegramsourcepropertymapping",
    "authentik_sources_telegram.change_usertelegramsourceconnection",
    "authentik_sources_telegram.delete_grouptelegramsourceconnection",
    "authentik_sources_telegram.delete_telegramsource",
    "authentik_sources_telegram.delete_telegramsourcepropertymapping",
    "authentik_sources_telegram.delete_usertelegramsourceconnection",
    "authentik_sources_telegram.view_grouptelegramsourceconnection",
    "authentik_sources_telegram.view_telegramsource",
    "authentik_sources_telegram.view_telegramsourcepropertymapping",
    "authentik_sources_telegram.view_usertelegramsourceconnection",
    "authentik_stages_authenticator_duo.add_authenticatorduostage",
    "authentik_stages_authenticator_duo.add_duodevice",
    "authentik_stages_authenticator_duo.change_authenticatorduostage",
@@ -7512,7 +7336,6 @@
    "authentik.sources.plex",
    "authentik.sources.saml",
    "authentik.sources.scim",
    "authentik.sources.telegram",
    "authentik.stages.authenticator",
    "authentik.stages.authenticator_duo",
    "authentik.stages.authenticator_email",
@@ -7623,10 +7446,6 @@
    "authentik_sources_saml.groupsamlsourceconnection",
    "authentik_sources_scim.scimsource",
    "authentik_sources_scim.scimsourcepropertymapping",
    "authentik_sources_telegram.telegramsource",
    "authentik_sources_telegram.telegramsourcepropertymapping",
    "authentik_sources_telegram.usertelegramsourceconnection",
    "authentik_sources_telegram.grouptelegramsourceconnection",
    "authentik_stages_authenticator_duo.authenticatorduostage",
    "authentik_stages_authenticator_duo.duodevice",
    "authentik_stages_authenticator_email.authenticatoremailstage",
@@ -10157,22 +9976,6 @@
    "authentik_sources_scim.view_scimsourcegroup",
    "authentik_sources_scim.view_scimsourcepropertymapping",
    "authentik_sources_scim.view_scimsourceuser",
    "authentik_sources_telegram.add_grouptelegramsourceconnection",
    "authentik_sources_telegram.add_telegramsource",
    "authentik_sources_telegram.add_telegramsourcepropertymapping",
    "authentik_sources_telegram.add_usertelegramsourceconnection",
    "authentik_sources_telegram.change_grouptelegramsourceconnection",
    "authentik_sources_telegram.change_telegramsource",
    "authentik_sources_telegram.change_telegramsourcepropertymapping",
    "authentik_sources_telegram.change_usertelegramsourceconnection",
    "authentik_sources_telegram.delete_grouptelegramsourceconnection",
    "authentik_sources_telegram.delete_telegramsource",
    "authentik_sources_telegram.delete_telegramsourcepropertymapping",
    "authentik_sources_telegram.delete_usertelegramsourceconnection",
    "authentik_sources_telegram.view_grouptelegramsourceconnection",
    "authentik_sources_telegram.view_telegramsource",
    "authentik_sources_telegram.view_telegramsourcepropertymapping",
    "authentik_sources_telegram.view_usertelegramsourceconnection",
    "authentik_stages_authenticator_duo.add_authenticatorduostage",
    "authentik_stages_authenticator_duo.add_duodevice",
    "authentik_stages_authenticator_duo.change_authenticatorduostage",
@@ -12265,289 +12068,6 @@
            }
        }
    },
    "model_authentik_sources_telegram.grouptelegramsourceconnection": {
        "type": "object",
        "properties": {
            "group": {
                "type": "string",
                "format": "uuid",
                "title": "Group"
            },
            "source": {
                "type": "integer",
                "title": "Source"
            },
            "identifier": {
                "type": "string",
                "minLength": 1,
                "title": "Identifier"
            },
            "icon": {
                "type": "string",
                "minLength": 1,
                "title": "Icon"
            }
        },
        "required": []
    },
    "model_authentik_sources_telegram.grouptelegramsourceconnection_permissions": {
        "type": "array",
        "items": {
            "type": "object",
            "required": [
                "permission"
            ],
            "properties": {
                "permission": {
                    "type": "string",
                    "enum": [
                        "add_grouptelegramsourceconnection",
                        "change_grouptelegramsourceconnection",
                        "delete_grouptelegramsourceconnection",
                        "view_grouptelegramsourceconnection"
                    ]
                },
                "user": {
                    "type": "integer"
                },
                "role": {
                    "type": "string"
                }
            }
        }
    },
    "model_authentik_sources_telegram.telegramsource": {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "minLength": 1,
                "title": "Name",
                "description": "Source's display Name."
            },
            "slug": {
                "type": "string",
                "maxLength": 50,
                "minLength": 1,
                "pattern": "^[-a-zA-Z0-9_]+$",
                "title": "Slug",
                "description": "Internal source name, used in URLs."
            },
            "enabled": {
                "type": "boolean",
                "title": "Enabled"
            },
            "authentication_flow": {
                "type": "string",
                "format": "uuid",
                "title": "Authentication flow",
                "description": "Flow to use when authenticating existing users."
            },
            "enrollment_flow": {
                "type": "string",
                "format": "uuid",
                "title": "Enrollment flow",
                "description": "Flow to use when enrolling new users."
            },
            "user_property_mappings": {
                "type": "array",
                "items": {
                    "type": "string",
                    "format": "uuid"
                },
                "title": "User property mappings"
            },
            "group_property_mappings": {
                "type": "array",
                "items": {
                    "type": "string",
                    "format": "uuid"
                },
                "title": "Group property mappings"
            },
            "policy_engine_mode": {
                "type": "string",
                "enum": [
                    "all",
                    "any"
                ],
                "title": "Policy engine mode"
            },
            "user_matching_mode": {
                "type": "string",
                "enum": [
                    "identifier",
                    "email_link",
                    "email_deny",
                    "username_link",
                    "username_deny"
                ],
                "title": "User matching mode",
                "description": "How the source determines if an existing user should be authenticated or a new user enrolled."
            },
            "user_path_template": {
                "type": "string",
                "minLength": 1,
                "title": "User path template"
            },
            "icon": {
                "type": "string",
                "minLength": 1,
                "title": "Icon"
            },
            "bot_username": {
                "type": "string",
                "minLength": 1,
                "title": "Bot username",
                "description": "Telegram bot username"
            },
            "bot_token": {
                "type": "string",
                "minLength": 1,
                "title": "Bot token",
                "description": "Telegram bot token"
            },
            "request_message_access": {
                "type": "boolean",
                "title": "Request message access",
                "description": "Request access to send messages from your bot."
            },
            "pre_authentication_flow": {
                "type": "string",
                "format": "uuid",
                "title": "Pre authentication flow",
                "description": "Flow used before authentication."
            }
        },
        "required": []
    },
    "model_authentik_sources_telegram.telegramsource_permissions": {
        "type": "array",
        "items": {
            "type": "object",
            "required": [
                "permission"
            ],
            "properties": {
                "permission": {
                    "type": "string",
                    "enum": [
                        "add_telegramsource",
                        "change_telegramsource",
                        "delete_telegramsource",
                        "view_telegramsource"
                    ]
                },
                "user": {
                    "type": "integer"
                },
                "role": {
                    "type": "string"
                }
            }
        }
    },
    "model_authentik_sources_telegram.telegramsourcepropertymapping": {
        "type": "object",
        "properties": {
            "managed": {
                "type": [
                    "string",
                    "null"
                ],
                "minLength": 1,
                "title": "Managed by authentik",
                "description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
            },
            "name": {
                "type": "string",
                "minLength": 1,
                "title": "Name"
            },
            "expression": {
                "type": "string",
                "minLength": 1,
                "title": "Expression"
            }
        },
        "required": []
    },
    "model_authentik_sources_telegram.telegramsourcepropertymapping_permissions": {
        "type": "array",
        "items": {
            "type": "object",
            "required": [
                "permission"
            ],
            "properties": {
                "permission": {
                    "type": "string",
                    "enum": [
                        "add_telegramsourcepropertymapping",
                        "change_telegramsourcepropertymapping",
                        "delete_telegramsourcepropertymapping",
                        "view_telegramsourcepropertymapping"
                    ]
                },
                "user": {
                    "type": "integer"
                },
                "role": {
                    "type": "string"
                }
            }
        }
    },
    "model_authentik_sources_telegram.usertelegramsourceconnection": {
        "type": "object",
        "properties": {
            "user": {
                "type": "integer",
                "title": "User"
            },
            "source": {
                "type": "integer",
                "title": "Source"
            },
            "identifier": {
                "type": "string",
                "minLength": 1,
                "title": "Identifier"
            },
            "icon": {
                "type": "string",
                "minLength": 1,
                "title": "Icon"
            }
        },
        "required": []
    },
    "model_authentik_sources_telegram.usertelegramsourceconnection_permissions": {
        "type": "array",
        "items": {
            "type": "object",
            "required": [
                "permission"
            ],
            "properties": {
                "permission": {
                    "type": "string",
                    "enum": [
                        "add_usertelegramsourceconnection",
                        "change_usertelegramsourceconnection",
                        "delete_usertelegramsourceconnection",
                        "view_usertelegramsourceconnection"
                    ]
                },
                "user": {
                    "type": "integer"
                },
                "role": {
                    "type": "string"
                }
            }
        }
    },
    "model_authentik_stages_authenticator_duo.authenticatorduostage": {
        "type": "object",
        "properties": {

6
go.mod
@@ -11,7 +11,7 @@ require (
    github.com/coreos/go-oidc/v3 v3.15.0
    github.com/getsentry/sentry-go v0.35.3
    github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
    github.com/go-ldap/ldap/v3 v3.4.12
    github.com/go-ldap/ldap/v3 v3.4.11
    github.com/go-openapi/runtime v0.29.0
    github.com/golang-jwt/jwt/v5 v5.3.0
    github.com/google/uuid v1.6.0
@@ -32,7 +32,7 @@ require (
    github.com/spf13/cobra v1.10.1
    github.com/stretchr/testify v1.11.1
    github.com/wwt/guac v1.3.2
    goauthentik.io/api/v3 v3.2025100.16
    goauthentik.io/api/v3 v3.2025100.14
    golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
    golang.org/x/oauth2 v0.31.0
    golang.org/x/sync v0.17.0
@@ -96,5 +96,3 @@ require (
    google.golang.org/protobuf v1.36.8 // indirect
    gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace goauthentik.io/api/v3 => ./gen-go-api

12
go.sum
@@ -4,8 +4,8 @@ beryju.io/radius-eap v0.1.0 h1:5M3HwkzH3nIEBcKDA2z5+sb4nCY3WdKL/SDDKTBvoqw=
beryju.io/radius-eap v0.1.0/go.mod h1:yYtO59iyoLNEepdyp1gZ0i1tGdjPbrR2M/v5yOz7Fkc=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk=
@@ -42,8 +42,8 @@ github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a h1:v6zMvHuY9
github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a/go.mod h1:I79BieaU4fxrw4LMXby6q5OS9XnoR9UIKLOzDFjUmuw=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-ldap/ldap/v3 v3.4.11 h1:4k0Yxweg+a3OyBLjdYn5OKglv18JNvfDykSoI8bW0gU=
github.com/go-ldap/ldap/v3 v3.4.11/go.mod h1:bY7t0FLK8OAVpp/vV6sSlpz3EQDGcQwc8pF0ujLgKvM=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -207,8 +207,8 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
goauthentik.io/api/v3 v3.2025100.16 h1:+tkfBSjpIt7eo267Y86LDmprrxo7bxiuLVKh/dZuGMg=
goauthentik.io/api/v3 v3.2025100.16/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
goauthentik.io/api/v3 v3.2025100.14 h1:fBRPhJ+nMIzD3AHC8ofQcBs6ZKvOWD+r/tOCnbZupX0=
goauthentik.io/api/v3 v3.2025100.14/go.mod h1:82lqAz4jxzl6Cg0YDbhNtvvTG2rm6605ZhdJFnbbsl8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=

@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1

# Stage 1: Build
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25.1-bookworm AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.25-bookworm AS builder

ARG TARGETOS
ARG TARGETARCH

16
lifecycle/aws/package-lock.json
generated
@@ -9,8 +9,8 @@
      "version": "0.0.0",
      "license": "MIT",
      "devDependencies": {
        "aws-cdk": "^2.1029.4",
        "cross-env": "^10.1.0"
        "aws-cdk": "^2.1029.3",
        "cross-env": "^10.0.0"
      },
      "engines": {
        "node": ">=20"
@@ -24,9 +24,9 @@
      "license": "MIT"
    },
    "node_modules/aws-cdk": {
      "version": "2.1029.4",
      "resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1029.4.tgz",
      "integrity": "sha512-rJa8QLd8WHaoTEjPLqVwmNpDMmyJycVaxdr/Evr/1MDLq+WCovP46IqPaXfH0q/jY0gCsga9or907tEayK5xcg==",
      "version": "2.1029.3",
      "resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1029.3.tgz",
      "integrity": "sha512-otRJP5a4r07S+SLKs/WvJH+0auZHkaRMnv1vtD4fpp1figV8Vr9MKdB4QPNjfKdLGyK9f95OEHwVlIW9xpjPBg==",
      "dev": true,
      "license": "Apache-2.0",
      "bin": {
@@ -40,9 +40,9 @@
      }
    },
    "node_modules/cross-env": {
      "version": "10.1.0",
      "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-10.1.0.tgz",
      "integrity": "sha512-GsYosgnACZTADcmEyJctkJIoqAhHjttw7RsFrVoJNXbsWWqaq6Ym+7kZjq6mS45O0jij6vtiReppKQEtqWy6Dw==",
      "version": "10.0.0",
      "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-10.0.0.tgz",
      "integrity": "sha512-aU8qlEK/nHYtVuN4p7UQgAwVljzMg8hB4YK5ThRqD2l/ziSnryncPNn7bMLt5cFYsKVKBh8HqLqyCoTupEUu7Q==",
      "dev": true,
      "license": "MIT",
      "dependencies": {

@@ -10,7 +10,7 @@
      "node": ">=20"
    },
    "devDependencies": {
      "aws-cdk": "^2.1029.4",
      "cross-env": "^10.1.0"
      "aws-cdk": "^2.1029.3",
      "cross-env": "^10.0.0"
    }
}

@@ -1,5 +1,7 @@
# flake8: noqa
from redis import Redis

from authentik.lib.config import CONFIG
from lifecycle.migrate import BaseMigration

SQL_STATEMENT = """BEGIN TRANSACTION;
@@ -104,3 +106,17 @@ class Migration(BaseMigration):
    def run(self):
        with self.con.transaction():
            self.cur.execute(SQL_STATEMENT)
        # We also need to clean the cache to make sure no pickled objects still exist
        for db in [
            CONFIG.get("redis.message_queue_db"),
            CONFIG.get("redis.cache_db"),
            CONFIG.get("redis.ws_db"),
        ]:
            redis = Redis(
                host=CONFIG.get("redis.host"),
                port=6379,
                db=db,
                username=CONFIG.get("redis.username"),
                password=CONFIG.get("redis.password"),
            )
            redis.flushall()

@@ -1,13 +1,14 @@
#!/usr/bin/env python
"""This file needs to be run from the root of the project to correctly
import authentik. This is done by the dockerfile."""

from sys import exit as sysexit
from time import sleep

from psycopg import OperationalError, connect
from redis import Redis
from redis.exceptions import RedisError

from authentik.lib.config import CONFIG
from authentik.lib.config import CONFIG, redis_url

CHECK_THRESHOLD = 30

@@ -39,6 +40,24 @@ def check_postgres():
    CONFIG.log("info", "PostgreSQL connection successful")


def check_redis():
    url = CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db"))
    attempt = 0
    while True:
        if attempt >= CHECK_THRESHOLD:
            sysexit(1)
        try:
            redis = Redis.from_url(url)
            redis.ping()
            break
        except RedisError as exc:
            sleep(1)
            CONFIG.log("info", f"Redis Connection failed, retrying... ({exc})")
        finally:
            attempt += 1
    CONFIG.log("info", "Redis Connection successful")


def wait_for_db():
    CONFIG.log("info", "Starting authentik bootstrap")
    # Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
@@ -50,6 +69,7 @@ def wait_for_db():
    CONFIG.log("info", "----------------------------------------------------------------------")
    sysexit(1)
    check_postgres()
    check_redis()
    CONFIG.log("info", "Finished authentik bootstrap")

Binary file not shown.
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-10-02 00:10+0000\n"
"POT-Creation-Date: 2025-09-26 00:10+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@@ -152,18 +152,6 @@ msgstr ""
msgid "No empty segments in user path allowed."
msgstr ""

#: authentik/core/api/users.py
msgid "A user/group with these details already exists"
msgstr ""

#: authentik/core/api/users.py
msgid "Unable to create user"
msgstr ""

#: authentik/core/api/users.py
msgid "Unknown error occurred"
msgstr ""

#: authentik/core/api/users.py
msgid "This field is required."
msgstr ""
@@ -1678,13 +1666,6 @@ msgid ""
"minutes=2;seconds=3)."
msgstr ""

#: authentik/providers/oauth2/models.py
msgid ""
"When refreshing a token, if the refresh token is valid for less than this "
"duration, it will be renewed. When set to seconds=0, token will always be "
"renewed. (Format: hours=1;minutes=2;seconds=3)."
msgstr ""

#: authentik/providers/oauth2/models.py
msgid ""
"Configure what data should be used as unique User Identifier. For most "
@@ -2817,12 +2798,6 @@ msgstr ""
msgid "Check the validity of a Plex source."
msgstr ""

#: authentik/sources/saml/api/source.py
msgid ""
"With a Verification Certificate selected, at least one of 'Verify Assertion "
"Signature' or 'Verify Response Signature' must be selected."
msgstr ""

#: authentik/sources/saml/models.py
msgid "Redirect Binding"
msgstr ""
@@ -2835,7 +2810,7 @@ msgstr ""
msgid "POST Binding with auto-confirmation"
msgstr ""

#: authentik/sources/saml/models.py authentik/sources/telegram/models.py
#: authentik/sources/saml/models.py
msgid "Flow used before authentication."
msgstr ""

@@ -2942,58 +2917,6 @@ msgstr ""
msgid "SCIM Source Property Mappings"
msgstr ""

#: authentik/sources/telegram/models.py authentik/sources/telegram/stage.py
msgid "Telegram bot username"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Telegram bot token"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Request access to send messages from your bot."
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Telegram Source"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Telegram Sources"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Telegram Source Property Mapping"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Telegram Source Property Mappings"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "User Telegram Source Connection"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "User Telegram Source Connections"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Group Telegram Source Connection"
msgstr ""

#: authentik/sources/telegram/models.py
msgid "Group Telegram Source Connections"
msgstr ""

#: authentik/sources/telegram/stage.py
msgid "Authentication date is too old"
msgstr ""

#: authentik/sources/telegram/stage.py
msgid "Invalid hash"
msgstr ""

#: authentik/stages/authenticator_duo/models.py
msgid "Duo Authenticator Setup Stage"
msgstr ""

@@ -11,13 +11,10 @@ class DjangoDramatiqPostgres(AppConfig):
    name = "django_dramatiq_postgres"
    verbose_name = "Django Dramatiq postgres"

    def ready(self) -> None:
        try:
            old_broker = dramatiq.get_broker()
        except ModuleNotFoundError:
            old_broker = None
    def ready(self):
        old_broker = dramatiq.get_broker()

        if old_broker is not None and len(old_broker.actors) != 0:
        if len(old_broker.actors) != 0:
            raise ImproperlyConfigured(
                "Actors were previously registered. "
                "Make sure your actors are not imported too early."
@@ -41,7 +38,7 @@ class DjangoDramatiqPostgres(AppConfig):
                *Conf().result_backend_args,
                **Conf().result_backend_kwargs,
            )
            broker.add_middleware(middleware)  # type: ignore[no-untyped-call]
            broker.add_middleware(middleware)

        dramatiq.set_broker(broker)

@@ -1,10 +1,9 @@
import functools
import logging
import time
from collections.abc import Callable, Iterable
from datetime import timedelta
from collections.abc import Iterable
from queue import Empty, Queue
from typing import Any, ParamSpec, TypeVar, cast
from typing import Any

import tenacity
from django.core.exceptions import ImproperlyConfigured
@@ -38,36 +37,31 @@ from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, T

logger = get_logger(__name__)

P = ParamSpec("P")
R = TypeVar("R")


def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"


def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]:
def raise_connection_error(func):
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except OperationalError as exc:
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
raise ConnectionError(str(exc)) from exc

return wrapper


class PostgresBroker(Broker):
queues: set[str] # type: ignore[assignment]

def __init__(
self,
*args: Any,
*args,
middleware: list[Middleware] | None = None,
db_alias: str = DEFAULT_DB_ALIAS,
**kwargs: Any,
) -> None:
super().__init__(*args, middleware=[], **kwargs) # type: ignore[no-untyped-call,misc]
**kwargs,
):
super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger(__name__, type(self))

self.queues = set()
@@ -81,7 +75,7 @@ class PostgresBroker(Broker):

@property
def connection(self) -> DatabaseWrapper:
return cast(DatabaseWrapper, connections[self.db_alias])
return connections[self.db_alias]

@property
def consumer_class(self) -> "type[_PostgresConsumer]":
@@ -89,12 +83,11 @@ class PostgresBroker(Broker):

@cached_property
def model(self) -> type[TaskBase]:
model: type[TaskBase] = import_string(Conf().task_model)
return model
return import_string(Conf().task_model)

@property
def query_set(self) -> QuerySet[TaskBase]:
return self.model._default_manager.using(self.db_alias).defer("message", "result")
def query_set(self) -> QuerySet:
return self.model.objects.using(self.db_alias).defer("message", "result")

def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
self.declare_queue(queue_name)
@@ -106,18 +99,18 @@ class PostgresBroker(Broker):
timeout=timeout,
)

def declare_queue(self, queue_name: str) -> None:
def declare_queue(self, queue_name: str):
if queue_name not in self.queues:
self.emit_before("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_before("declare_queue", queue_name)
self.queues.add(queue_name)
# Nothing to do, all queues are in the same table
self.emit_after("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_after("declare_queue", queue_name)

delayed_name = dq_name(queue_name) # type: ignore[no-untyped-call]
delayed_name = dq_name(queue_name)
self.delay_queues.add(delayed_name)
self.emit_after("declare_delay_queue", delayed_name) # type: ignore[no-untyped-call]
self.emit_after("declare_delay_queue", delayed_name)

def model_defaults(self, message: Message[Any]) -> dict[str, Any]:
def model_defaults(self, message: Message) -> dict[str, Any]:
return {
"queue_name": message.queue_name,
"actor_name": message.actor_name,
@@ -137,16 +130,14 @@ class PostgresBroker(Broker):
reraise=True,
wait=tenacity.wait_random_exponential(multiplier=1, max=5),
stop=tenacity.stop_after_attempt(3),
before_sleep=tenacity.before_sleep_log(
cast(logging.Logger, logger), logging.INFO, exc_info=True
),
before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
)
def enqueue(self, message: Message[Any], *, delay: int | None = None) -> Message[Any]:
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
canonical_queue_name = message.queue_name
queue_name = canonical_queue_name
if delay:
queue_name = dq_name(queue_name) # type: ignore[no-untyped-call]
message_eta = current_millis() + delay # type: ignore[no-untyped-call]
queue_name = dq_name(queue_name)
message_eta = current_millis() + delay
message = message.copy(
queue_name=queue_name,
options={
@@ -160,7 +151,7 @@ class PostgresBroker(Broker):
)

message.options["model_defaults"] = self.model_defaults(message)
self.emit_before("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_before("enqueue", message, delay)

with transaction.atomic(using=self.db_alias):
query = {
@@ -182,18 +173,18 @@ class PostgresBroker(Broker):
message.options["task"] = task
message.options["task_created"] = created

self.emit_after("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_after("enqueue", message, delay)
return message

def get_declared_queues(self) -> set[str]:
return self.queues.copy()

def flush(self, queue_name: str) -> None:
def flush(self, queue_name: str):
self.query_set.filter(
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name)) # type: ignore[no-untyped-call]
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
).delete()

def flush_all(self) -> None:
def flush_all(self):
for queue_name in self.queues:
self.flush(queue_name)

@@ -203,11 +194,11 @@ class PostgresBroker(Broker):
interval: int = 100,
*,
timeout: int | None = None,
) -> None:
):
deadline = timeout and time.monotonic() + timeout / 1000
while True:
if deadline and time.monotonic() >= deadline:
raise QueueJoinTimeout(queue_name) # type: ignore[no-untyped-call]
raise QueueJoinTimeout(queue_name)

if self.query_set.filter(
queue_name=queue_name,
@@ -221,14 +212,14 @@ class PostgresBroker(Broker):
class _PostgresConsumer(Consumer):
def __init__(
self,
*args: Any,
*args,
broker: PostgresBroker,
db_alias: str,
queue_name: str,
prefetch: int,
timeout: int,
**kwargs: Any,
) -> None:
**kwargs,
):
self.logger = get_logger(__name__, type(self))

self.notifies: list[Notify] = []
@@ -236,8 +227,8 @@ class _PostgresConsumer(Consumer):
self.db_alias = db_alias
self.queue_name = queue_name
self.timeout = timeout // 1000
self.unlock_queue: Queue[str] = Queue()
self.in_processing: set[str] = set()
self.unlock_queue = Queue()
self.in_processing = set()
self.prefetch = prefetch
self.misses = 0
self._listen_connection: DatabaseWrapper | None = None
@@ -246,34 +237,32 @@ class _PostgresConsumer(Consumer):
# Override because dramatiq doesn't allow us setting this manually
self.timeout = Conf().worker["consumer_listen_timeout"]

self.lock_purge_interval = timedelta(seconds=Conf().lock_purge_interval)
self.lock_purge_interval = timezone.timedelta(seconds=Conf().lock_purge_interval)
self.lock_purge_last_run = timezone.now()

self.task_purge_interval = timedelta(seconds=Conf().task_purge_interval)
self.task_purge_interval = timezone.timedelta(seconds=Conf().task_purge_interval)
self.task_purge_last_run = timezone.now() - self.task_purge_interval

self.scheduler = None
if Conf().schedule_model:
self.scheduler = import_string(Conf().scheduler_class)()
self.scheduler.broker = self.broker
self.scheduler_interval = timedelta(seconds=Conf().scheduler_interval)
self.scheduler_interval = timezone.timedelta(seconds=Conf().scheduler_interval)
self.scheduler_last_run = timezone.now() - self.scheduler_interval

@property
def connection(self) -> DatabaseWrapper:
return cast(DatabaseWrapper, connections[self.db_alias])
return connections[self.db_alias]

@property
def query_set(self) -> QuerySet[TaskBase]:
def query_set(self) -> QuerySet:
return self.broker.query_set

@property
def listen_connection(self) -> DatabaseWrapper:
if self._listen_connection is not None and self._listen_connection.is_usable():
return self._listen_connection
self._listen_connection = cast(
DatabaseWrapper, connections.create_connection(self.db_alias)
)
self._listen_connection = connections.create_connection(self.db_alias)
# Required for notifications
# See https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-notifications
# Should be set to True by Django by default
@@ -283,7 +272,7 @@ class _PostgresConsumer(Consumer):
return self._listen_connection

@raise_connection_error
def ack(self, message: Message[Any]) -> None:
def ack(self, message: Message):
task = message.options.pop("task", None)
self.query_set.filter(
message_id=message.message_id,
@@ -298,7 +287,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.remove(message.message_id)

@raise_connection_error
def nack(self, message: Message[Any]) -> None:
def nack(self, message: Message):
task = message.options.pop("task", None)
self.query_set.filter(
message_id=message.message_id,
@@ -314,7 +303,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.remove(message.message_id)

@raise_connection_error
def requeue(self, messages: Iterable[Message[Any]]) -> None:
def requeue(self, messages: Iterable[Message]):
self.query_set.filter(
message_id__in=[message.message_id for message in messages],
).update(
@@ -337,9 +326,7 @@ class _PostgresConsumer(Consumer):
)
.values_list("message_id", flat=True)
)
return [
Notify(pid=0, channel=self.postgres_channel, payload=str(item)) for item in notifies
]
return [Notify(pid=0, channel=self.postgres_channel, payload=item) for item in notifies]

def _poll_for_notify(self) -> list[Notify]:
with self.listen_connection.cursor() as cursor:
@@ -352,12 +339,11 @@ class _PostgresConsumer(Consumer):
return notifies

def _get_message_lock_id(self, message_id: str) -> int:
lock_id = _cast_lock_id(
return _cast_lock_id(
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}"
) # type: ignore[no-untyped-call]
return cast(int, lock_id)
)

def _consume_one(self, message: Message[Any]) -> bool:
def _consume_one(self, message: Message) -> bool:
if message.message_id in self.in_processing:
self.logger.debug("Message already consumed by self", message_id=message.message_id)
return False
@@ -402,7 +388,7 @@ class _PostgresConsumer(Consumer):
# If we have too many messages already processing, wait and don't consume a message
# straight away, other workers will be faster.
# After waiting consume a message regardless.
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000) # type: ignore[no-untyped-call]
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
self.logger.debug(
"Too many messages in processing, Sleeping",
processing=processing,
@@ -426,11 +412,11 @@ class _PostgresConsumer(Consumer):
)
if task is None:
continue
message = Message.decode(cast(bytes, task.message))
message = Message.decode(task.message)
message.options["task"] = task
if self._consume_one(message):
self.in_processing.add(message.message_id)
return MessageProxy(message) # type: ignore[no-untyped-call]
return MessageProxy(message)
else:
self.logger.debug(
"Message already consumed. Skipping.", message_id=message.message_id
@@ -443,7 +429,7 @@ class _PostgresConsumer(Consumer):
self.misses = 0
return None

def _purge_locks(self) -> None:
def _purge_locks(self):
if timezone.now() - self.lock_purge_last_run < self.lock_purge_interval:
return
while True:
@@ -459,19 +445,19 @@ class _PostgresConsumer(Consumer):
self.unlock_queue.task_done()
self.lock_purge_last_run = timezone.now()

def _auto_purge(self) -> None:
def _auto_purge(self):
if timezone.now() - self.task_purge_last_run < self.task_purge_interval:
return
self.logger.debug("Running garbage collector")
count = self.query_set.filter(
state__in=(TaskState.DONE, TaskState.REJECTED),
mtime__lte=timezone.now() - timedelta(seconds=Conf().task_expiration),
mtime__lte=timezone.now() - timezone.timedelta(seconds=Conf().task_purge_interval),
result_expiry__lte=timezone.now(),
).delete()
self.logger.info("Purged messages in all queues", count=count)
self.task_purge_last_run = timezone.now()

def _scheduler(self) -> None:
def _scheduler(self):
if not self.scheduler:
return
if timezone.now() - self.scheduler_last_run < self.scheduler_interval:
@@ -480,7 +466,7 @@ class _PostgresConsumer(Consumer):
self.schedule_last_run = timezone.now()

@raise_connection_error
def close(self) -> None:
def close(self):
try:
self._purge_locks()
finally:

@@ -1,11 +1,11 @@
from typing import Any, cast
from typing import Any

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured


class Conf:
def __init__(self) -> None:
def __init__(self):
try:
_ = settings.DRAMATIQ
except AttributeError as exc:
@@ -19,75 +19,68 @@ class Conf:

@property
def encoder_class(self) -> str:
return cast(str, self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder"))
return self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder")

@property
def broker_class(self) -> str:
return cast(
str, self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker")
)
return self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker")

@property
def broker_args(self) -> tuple[Any]:
return cast(tuple[Any], self.conf.get("broker_args", tuple()))
return self.conf.get("broker_args", ())

@property
def broker_kwargs(self) -> dict[str, Any]:
return cast(dict[str, Any], self.conf.get("broker_kwargs", {}))
return self.conf.get("broker_kwargs", {})

@property
def middlewares(self) -> tuple[tuple[str, dict[str, Any]]]:
return cast(
tuple[tuple[str, dict[str, Any]]],
self.conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
return self.conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
)

@property
def channel_prefix(self) -> str:
return cast(str, self.conf.get("channel_prefix", "dramatiq"))
return self.conf.get("channel_prefix", "dramatiq")

@property
def task_model(self) -> str:
return cast(str, self.conf["task_model"])
return self.conf["task_model"]

@property
def lock_purge_interval(self) -> int:
return cast(int, self.conf.get("lock_purge_interval", 60))
return self.conf.get("lock_purge_interval", 60)

@property
def task_purge_interval(self) -> int:
# 24 hours
return cast(int, self.conf.get("task_purge_interval", 24 * 60 * 60))
return self.conf.get("task_purge_interval", 24 * 60 * 60)

@property
def task_expiration(self) -> int:
# 30 days
return cast(int, self.conf.get("task_expiration", 60 * 60 * 24 * 30))
return self.conf.get("task_expiration", 60 * 60 * 24 * 30)

@property
def result_backend(self) -> str:
return cast(
str, self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend")
)
return self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend")

@property
def result_backend_args(self) -> tuple[Any]:
return cast(tuple[Any], self.conf.get("result_backend_args", ()))
return self.conf.get("result_backend_args", ())

@property
def result_backend_kwargs(self) -> dict[str, Any]:
return cast(dict[str, Any], self.conf.get("result_backend_kwargs", {}))
return self.conf.get("result_backend_kwargs", {})

@property
def autodiscovery(self) -> dict[str, Any]:
@@ -120,18 +113,16 @@ class Conf:

@property
def scheduler_class(self) -> str:
return cast(
str, self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler")
)
return self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler")

@property
def schedule_model(self) -> str | None:
return cast(str | None, self.conf.get("schedule_model"))
return self.conf.get("schedule_model")

@property
def scheduler_interval(self) -> int:
return cast(int, self.conf.get("scheduler_interval", 60))
return self.conf.get("scheduler_interval", 60)

@property
def test(self) -> bool:
return cast(bool, self.conf.get("test", False))
return self.conf.get("test", False)

@@ -1,11 +1,12 @@
import os
from importlib import import_module
from signal import pause

from django.utils.module_loading import import_module

from django_dramatiq_postgres.conf import Conf


def worker_metrics() -> None:
def worker_metrics():
import_module(Conf().autodiscovery["setup_module"])

from django_dramatiq_postgres.middleware import MetricsMiddleware

@@ -1,12 +1,11 @@
import sys
from argparse import Namespace
from typing import Any

from django.apps.registry import apps
from django.core.management.base import BaseCommand, CommandParser
from django.core.management.base import BaseCommand
from django.db import connections
from django.utils.module_loading import import_string, module_has_submodule
from dramatiq.cli import main
from dramatiq.__main__ import main

from django_dramatiq_postgres.conf import Conf

@@ -14,7 +13,7 @@ from django_dramatiq_postgres.conf import Conf
class Command(BaseCommand):
"""Run worker"""

def add_arguments(self, parser: CommandParser) -> None:
def add_arguments(self, parser):
parser.add_argument(
"--pid-file",
action="store",
@@ -32,11 +31,11 @@ class Command(BaseCommand):

def handle(
self,
pid_file: str,
watch: bool,
verbosity: int,
**options: Any,
) -> None:
pid_file,
watch,
verbosity,
**options,
):
worker = Conf().worker
setup, modules = self._discover_tasks_modules()
args = Namespace(
@@ -71,7 +70,7 @@ class Command(BaseCommand):
args.verbose = verbosity - 1

connections.close_all()
sys.exit(main(args)) # type: ignore[no-untyped-call]
sys.exit(main(args))

def _discover_tasks_modules(self) -> tuple[str, list[str]]:
# Does not support a tasks directory

@@ -1,13 +1,15 @@
import contextvars
import os
import socket
from collections.abc import Callable
from http.server import BaseHTTPRequestHandler
from http.server import HTTPServer as BaseHTTPServer
from ipaddress import IPv6Address, ip_address
from typing import Any, cast
from typing import Any

from django.db import DatabaseError, close_old_connections, connections
from django.db import (
close_old_connections,
connections,
)
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.common import current_millis
@@ -20,11 +22,10 @@ from django_dramatiq_postgres.models import TaskBase


class HTTPServer(BaseHTTPServer):
def server_bind(self) -> None:
def server_bind(self):
self.socket.close()

host, port = self.server_address[:2]
host = cast(str, host)
if host == "0.0.0.0" and socket.has_dualstack_ipv6(): # nosec
host = "::" # nosec

@@ -51,7 +52,7 @@ class HTTPServer(BaseHTTPServer):


class DbConnectionMiddleware(Middleware):
def _close_old_connections(self, *args: Any, **kwargs: Any) -> None:
def _close_old_connections(self, *args, **kwargs):
if Conf().test:
return
close_old_connections()
@@ -59,7 +60,7 @@ class DbConnectionMiddleware(Middleware):
before_process_message = _close_old_connections
after_process_message = _close_old_connections

def _close_connections(self, *args: Any, **kwargs: Any) -> None:
def _close_connections(self, *args, **kwargs):
connections.close_all()

before_consumer_thread_shutdown = _close_connections
@@ -68,7 +69,7 @@ class DbConnectionMiddleware(Middleware):


class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor[Any, Any]) -> None:
def before_declare_actor(self, broker: Broker, actor: Actor):
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"


@@ -79,7 +80,7 @@ class CurrentTaskNotFound(Exception):


class CurrentTask(Middleware):
def __init__(self) -> None:
def __init__(self):
self.logger = get_logger(__name__, type(self))

# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
@@ -95,7 +96,7 @@ class CurrentTask(Middleware):
raise CurrentTaskNotFound()
return task[-1]

def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: Message):
tasks = self._TASKS.get()
if tasks is None:
tasks = []
@@ -105,11 +106,11 @@ class CurrentTask(Middleware):
def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
):
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
return
@@ -128,19 +129,13 @@ class CurrentTask(Middleware):
fields_to_update = [
f.name
for f in task._meta.get_fields()
if f.name not in fields_to_exclude
and f.concrete
and not f.auto_created
and not f.many_to_many
if f.name not in fields_to_exclude and not f.auto_created and f.column
]
if fields_to_update:
try:
task.save(update_fields=fields_to_update)
except DatabaseError:
pass
task.save(update_fields=fields_to_update)
self._TASKS.set(tasks[:-1])

def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
def after_skip_message(self, broker: Broker, message: Message):
self.after_process_message(broker, message)


@@ -149,21 +144,21 @@ class MetricsMiddleware(Middleware):
self,
prefix: str,
labels: list[str] | None = None,
) -> None:
):
super().__init__()
self.prefix = prefix
self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]

self.delayed_messages: set[str] = set()
self.message_start_times: dict[str, int] = {}
self.delayed_messages = set()
self.message_start_times = {}

@property
def forks(self) -> list[Callable[[], None]]:
def forks(self):
from django_dramatiq_postgres.forks import worker_metrics

return [worker_metrics]

def before_worker_boot(self, broker: Broker, worker: Any) -> None:
def before_worker_boot(self, broker: Broker, worker):
if Conf().test:
return

@@ -229,47 +224,47 @@ class MetricsMiddleware(Middleware):
),
)

def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
def after_worker_shutdown(self, broker: Broker, worker):
from prometheus_client import multiprocess

# TODO: worker_id
multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call]
multiprocess.mark_process_dead(os.getpid())

def _make_labels(self, message: Message[Any]) -> list[str]:
def _make_labels(self, message: Message) -> list[str]:
return [message.queue_name, message.actor_name]

def after_nack(self, broker: Broker, message: Message[Any]) -> None:
def after_nack(self, broker: Broker, message: Message):
self.total_rejected_messages.labels(*self._make_labels(message)).inc()

def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
def after_enqueue(self, broker: Broker, message: Message, delay: int):
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()

def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
def before_delay_message(self, broker: Broker, message: Message):
self.delayed_messages.add(message.message_id)
self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()

def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: Message):
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.in_progress_delayed_messages.labels(*labels).dec()

self.in_progress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call]
self.message_start_times[message.message_id] = current_millis()

def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
):
labels = self._make_labels(message)

message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call]
message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call]
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
message_duration = current_millis() - message_start_time
self.messages_durations.labels(*labels).observe(message_duration)

self.in_progress_messages.labels(*labels).dec()
@@ -280,7 +275,7 @@ class MetricsMiddleware(Middleware):
after_skip_message = after_process_message

@staticmethod
def run(addr: str, port: int) -> None:
def run(addr: str, port: int):
try:
server = HTTPServer((addr, port), _MetricsHandler)
server.serve_forever()
@@ -291,7 +286,7 @@ class MetricsMiddleware(Middleware):


class _MetricsHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
def do_GET(self):
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
@@ -300,13 +295,13 @@ class _MetricsHandler(BaseHTTPRequestHandler):
)

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call]
multiprocess.MultiProcessCollector(registry)
output = generate_latest(registry)
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
self.end_headers()
self.wfile.write(output)

def log_message(self, format: str, *args: Any) -> None:
def log_message(self, format, *args):
logger = get_logger(__name__, type(self))
logger.debug(format, *args)

@@ -1,7 +1,5 @@
import pickle # nosec
from datetime import datetime, timedelta
from enum import StrEnum, auto
from typing import Any
from uuid import uuid4

import pgtrigger
@@ -10,7 +8,7 @@ from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db import models
from django.utils.timezone import now
from django.utils.timezone import datetime, now, timedelta
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.broker import Broker, get_broker
@@ -73,11 +71,11 @@ class TaskBase(models.Model):
),
)

def __str__(self) -> str:
def __str__(self):
return str(self.message_id)


def validate_crontab(value: str) -> None:
def validate_crontab(value):
try:
Cron(value)
except ValueError as exc:
@@ -121,14 +119,14 @@ class ScheduleBase(models.Model):
),
)

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._original_crontab = self.crontab

def __str__(self) -> str:
def __str__(self):
return f"Schedule {self.actor_name} ({self.id})"

def save(self, *args: Any, **kwargs: Any) -> None:
def save(self, *args, **kwargs):
if self.crontab != self._original_crontab:
self.next_run = self.compute_next_run(now())

@@ -137,16 +135,16 @@ class ScheduleBase(models.Model):
self._original_crontab = self.crontab

@classmethod
def dispatch_by_actor(cls, actor: Actor[Any, Any]) -> None:
def dispatch_by_actor(cls, actor: Actor):
"""Dispatch a schedule by looking up its actor.
Only available for schedules without custom arguments."""
schedule = cls._default_manager.filter(actor_name=actor.actor_name, paused=False).first()
schedule = cls.objects.filter(actor_name=actor.actor_name, paused=False).first()
if schedule:
schedule.send()

def send(self, broker: Broker | None = None) -> Message[Any]:
def send(self, broker: Broker | None = None) -> Message:
broker = broker or get_broker()
actor: Actor[Any, Any] = broker.get_actor(self.actor_name) # type: ignore[no-untyped-call]
actor: Actor = broker.get_actor(self.actor_name)
return actor.send_with_options(
args=pickle.loads(self.args), # nosec
kwargs=pickle.loads(self.kwargs), # nosec
@@ -155,7 +153,7 @@ class ScheduleBase(models.Model):
)

def compute_next_run(self, next_run: datetime | None = None) -> datetime:
next_run = self.next_run if not next_run else next_run
next_run: datetime = self.next_run if not next_run else next_run
while True:
next_run = Cron(self.crontab).schedule(next_run).next()
if next_run > now():

@@ -1,6 +1,3 @@
from datetime import timedelta
from typing import Any, cast

from django.db import DEFAULT_DB_ALIAS
from django.db.models import QuerySet
from django.utils import timezone
@@ -14,27 +11,26 @@ from django_dramatiq_postgres.models import TaskBase


class PostgresBackend(ResultBackend):
def __init__(self, *args: Any, db_alias: str = DEFAULT_DB_ALIAS, **kwargs: Any) -> None:
def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, **kwargs):
super().__init__(*args, **kwargs)
self.db_alias = db_alias

@cached_property
def model(self) -> type[TaskBase]:
model: type[TaskBase] = import_string(Conf().task_model)
return model
return import_string(Conf().task_model)

@property
def query_set(self) -> QuerySet[TaskBase]:
return self.model._default_manager.using(self.db_alias).defer("message")
def query_set(self) -> QuerySet:
return self.model.objects.using(self.db_alias).defer("message")

def build_message_key(self, message: Message[Result]) -> str:
def build_message_key(self, message: Message) -> str:
return str(message.message_id)

def _get(self, message_key: str) -> MResult:
message = self.query_set.filter(message_id=message_key).first()
if message is None:
return Missing
data = cast(bytes | None, message.result)
data = message.result
if data is None:
return Missing
return self.encoder.decode(data)
@@ -43,5 +39,5 @@ class PostgresBackend(ResultBackend):
self.query_set.filter(message_id=message_key).update(
mtime=timezone.now(),
result=self.encoder.encode(result),
result_expiry=timezone.now() + timedelta(milliseconds=ttl),
result_expiry=timezone.now() + timezone.timedelta(milliseconds=ttl),
)

@@ -1,5 +1,3 @@
from typing import Any, cast

import pglock
from django.db import router, transaction
from django.db.models import QuerySet
@@ -16,21 +14,19 @@ from django_dramatiq_postgres.models import ScheduleBase
class Scheduler:
broker: Broker

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger(__name__, type(self))

@cached_property
def model(self) -> type[ScheduleBase]:
schedule_model = cast(str, Conf().schedule_model)
model: type[ScheduleBase] = import_string(schedule_model)
return model
return import_string(Conf().schedule_model)

@property
def query_set(self) -> QuerySet[ScheduleBase]:
return self.model._default_manager.filter(paused=False)
def query_set(self) -> QuerySet:
return self.model.objects.filter(paused=False)

def process_schedule(self, schedule: ScheduleBase) -> None:
def process_schedule(self, schedule: ScheduleBase):
schedule.next_run = schedule.compute_next_run()
schedule.send(self.broker)
schedule.save()
@@ -52,11 +48,10 @@ class Scheduler:
count += 1
return count

def run(self) -> int:
def run(self):
with self._lock() as lock_acquired:
if not lock_acquired:
self.logger.debug("Could not acquire lock, skipping scheduling")
return -1
return
count = self._run()
self.logger.info("Sent scheduled tasks", count=count)
return count

@@ -27,7 +27,6 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Programming Language :: Python",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
]

dependencies = [
@@ -36,7 +35,7 @@ dependencies = [
"django-pgtrigger >=4,<5",
"dramatiq[watch] >=1.17,<1.18",
"tenacity >=9,<10",
"structlog >=25,<26",
"structlog >=25,<26"
]

[project.urls]

@@ -1,12 +0,0 @@
# django-postgres-cache

### Use in migrations

Migrations that use this cache need to depend on the migration that creates the cache entry table:

```python
dependencies = [
# ...other requirements
("django_postgres_cache", "0001_initial"),
]
```
@@ -1,6 +0,0 @@
from django.apps import AppConfig


class DjangoPostgresCache(AppConfig):
name = "django_postgres_cache"
verbose_name = "Django Postgres cache"
@@ -1,51 +0,0 @@
from typing import Any

from django.core.cache.backends.db import DatabaseCache as BaseDatabaseCache
from django.db.utils import ProgrammingError
from django.utils.module_loading import import_string
from django.utils.timezone import now

from django_postgres_cache.models import CacheEntry


class DatabaseCache(BaseDatabaseCache):

def __init__(self, table: str, params: dict[str, Any]) -> None:
super().__init__(table, params)
self.reverse_key_func = import_string(params["REVERSE_KEY_FUNCTION"])
self._table = CacheEntry._meta.db_table
self.cache_model_class = CacheEntry

def _cull(self, *args: Any, **kwargs: Any) -> None:
"""Stubbed out cull method as we cull in a background task"""
pass

def get(self, key: str, default: Any | None = None, version: int | None = None) -> Any:
try:
return super().get(key, default=default, version=version)
except ProgrammingError:
return default

def keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
try:
return self._keys(keys_pattern, version=version)
except ProgrammingError:
return []

def _keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
keys_pattern = self.make_key(keys_pattern.replace("*", ".*"), version=version)

return [
self.reverse_key_func(key)
for key in CacheEntry.objects.filter(cache_key__regex=keys_pattern).values_list(
"cache_key", flat=True
)
]

def ttl(self, key: str, version: int | None = None) -> int | None:
"""Get TTL left for a given key and version"""
key = self.make_and_validate_key(key, version=version)
entry = CacheEntry.objects.filter(cache_key=key).first()
if not entry:
return None
return int((entry.expires - now()).total_seconds())
@@ -1,24 +0,0 @@
# Generated by Django 5.1.12 on 2025-09-06 16:16

from django.db import migrations, models


class Migration(migrations.Migration):

initial = True

dependencies = []

operations = [
migrations.CreateModel(
name="CacheEntry",
fields=[
("cache_key", models.TextField(primary_key=True, serialize=False)),
("value", models.TextField()),
("expires", models.DateTimeField(db_index=True)),
],
options={
"default_permissions": [],
},
),
]
@@ -1,14 +0,0 @@
from django.db import models


class CacheEntry(models.Model):

cache_key = models.TextField(primary_key=True)
value = models.TextField()
expires = models.DateTimeField(db_index=True)

class Meta:
default_permissions = []

def __str__(self) -> str:
return f"Cache entry '{self.cache_key}'"
Some files were not shown because too many files have changed in this diff