From 29f20a48298c0e4df948678e2117fb38c79453a1 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Mon, 28 Jul 2025 17:00:09 +0200 Subject: [PATCH] *: replace Celery with Dramatiq (#13492) --- .gitignore | 5 - Dockerfile | 2 + Makefile | 7 +- authentik/admin/api/version.py | 2 +- authentik/admin/api/workers.py | 57 -- authentik/admin/apps.py | 15 + authentik/admin/settings.py | 15 - authentik/admin/signals.py | 35 - authentik/admin/tasks.py | 18 +- authentik/admin/tests/test_api.py | 7 - authentik/admin/tests/test_tasks.py | 11 +- authentik/admin/urls.py | 2 - authentik/blueprints/api.py | 6 +- authentik/blueprints/apps.py | 59 +- authentik/blueprints/models.py | 8 + authentik/blueprints/settings.py | 18 - authentik/blueprints/tasks.py | 2 + authentik/blueprints/tests/test_v1_tasks.py | 6 +- authentik/blueprints/v1/importer.py | 4 +- .../blueprints/v1/meta/apply_blueprint.py | 2 +- authentik/blueprints/v1/tasks.py | 94 +- authentik/brands/apps.py | 1 + authentik/core/apps.py | 26 +- .../management/commands/bootstrap_tasks.py | 21 - authentik/core/management/commands/worker.py | 47 - authentik/core/tasks.py | 28 +- authentik/core/tests/test_tasks.py | 4 +- authentik/crypto/apps.py | 13 + authentik/crypto/settings.py | 13 - authentik/crypto/tasks.py | 17 +- authentik/crypto/tests.py | 2 +- authentik/enterprise/apps.py | 13 + .../policies/unique_password/apps.py | 20 + .../policies/unique_password/settings.py | 20 - .../policies/unique_password/tasks.py | 34 +- .../unique_password/tests/test_tasks.py | 10 +- .../google_workspace/api/providers.py | 2 +- .../providers/google_workspace/models.py | 7 + .../providers/google_workspace/settings.py | 13 - .../providers/google_workspace/signals.py | 10 +- .../providers/google_workspace/tasks.py | 43 +- .../google_workspace/tests/test_groups.py | 2 +- .../google_workspace/tests/test_users.py | 2 +- .../microsoft_entra/api/providers.py | 2 +- .../providers/microsoft_entra/models.py | 7 + 
.../providers/microsoft_entra/settings.py | 13 - .../providers/microsoft_entra/signals.py | 10 +- .../providers/microsoft_entra/tasks.py | 41 +- .../microsoft_entra/tests/test_groups.py | 18 +- .../microsoft_entra/tests/test_users.py | 2 +- authentik/enterprise/providers/ssf/models.py | 3 +- authentik/enterprise/providers/ssf/signals.py | 10 +- authentik/enterprise/providers/ssf/tasks.py | 84 +- .../enterprise/providers/ssf/views/stream.py | 4 +- authentik/enterprise/settings.py | 12 - authentik/enterprise/signals.py | 3 +- authentik/enterprise/tasks.py | 13 +- authentik/events/api/tasks.py | 104 -- authentik/events/apps.py | 54 +- .../0011_alter_systemtask_options.py | 22 + authentik/events/models.py | 69 +- authentik/events/settings.py | 13 - authentik/events/signals.py | 21 +- authentik/events/system_tasks.py | 156 --- authentik/events/tasks.py | 107 +- authentik/events/tests/test_tasks.py | 103 -- authentik/events/urls.py | 2 - authentik/lib/config.py | 3 +- authentik/lib/debug.py | 6 +- authentik/lib/default.yml | 17 +- authentik/lib/logging.py | 1 - authentik/lib/sentry.py | 11 - authentik/lib/sync/api.py | 12 + authentik/lib/sync/outgoing/__init__.py | 2 +- authentik/lib/sync/outgoing/api.py | 94 +- authentik/lib/sync/outgoing/models.py | 40 +- authentik/lib/sync/outgoing/signals.py | 57 +- authentik/lib/sync/outgoing/tasks.py | 420 ++++---- authentik/lib/tests/test_config.py | 9 +- authentik/outposts/apps.py | 26 + authentik/outposts/controllers/kubernetes.py | 16 +- authentik/outposts/models.py | 40 +- authentik/outposts/settings.py | 28 - authentik/outposts/signals.py | 105 +- authentik/outposts/tasks.py | 163 +--- authentik/outposts/tests/test_sa.py | 1 + authentik/providers/proxy/apps.py | 1 + authentik/providers/proxy/signals.py | 13 + authentik/providers/proxy/tasks.py | 26 + authentik/providers/rac/models.py | 3 +- authentik/providers/scim/api/providers.py | 2 +- .../scim/management/commands/scim_sync.py | 4 +- authentik/providers/scim/models.py | 7 
+ authentik/providers/scim/settings.py | 13 - authentik/providers/scim/signals.py | 7 +- authentik/providers/scim/tasks.py | 35 +- authentik/providers/scim/tests/test_client.py | 8 +- .../providers/scim/tests/test_membership.py | 38 +- authentik/providers/scim/tests/test_user.py | 17 +- authentik/root/celery.py | 167 ---- authentik/root/settings.py | 119 ++- authentik/root/test_runner.py | 5 +- authentik/sources/kerberos/api/source.py | 71 +- authentik/sources/kerberos/models.py | 26 +- authentik/sources/kerberos/settings.py | 18 - authentik/sources/kerberos/signals.py | 15 +- authentik/sources/kerberos/sync.py | 19 +- authentik/sources/kerberos/tasks.py | 58 +- authentik/sources/kerberos/tests/test_sync.py | 9 +- authentik/sources/ldap/api.py | 60 +- .../ldap/management/commands/ldap_sync.py | 13 +- authentik/sources/ldap/models.py | 26 +- authentik/sources/ldap/settings.py | 18 - authentik/sources/ldap/signals.py | 19 - authentik/sources/ldap/sync/base.py | 20 +- .../ldap/sync/forward_delete_groups.py | 4 +- .../sources/ldap/sync/forward_delete_users.py | 4 +- authentik/sources/ldap/sync/groups.py | 14 +- authentik/sources/ldap/sync/membership.py | 13 +- authentik/sources/ldap/sync/users.py | 19 +- authentik/sources/ldap/sync/vendor/freeipa.py | 2 +- authentik/sources/ldap/sync/vendor/ms_ad.py | 2 +- authentik/sources/ldap/tasks.py | 148 +-- authentik/sources/ldap/tests/test_auth.py | 7 +- authentik/sources/ldap/tests/test_sync.py | 68 +- authentik/sources/oauth/apps.py | 13 + authentik/sources/oauth/settings.py | 13 - authentik/sources/oauth/tasks.py | 25 +- authentik/sources/oauth/tests/test_tasks.py | 2 +- authentik/sources/plex/models.py | 18 +- authentik/sources/plex/settings.py | 13 - authentik/sources/plex/tasks.py | 28 +- authentik/sources/plex/tests.py | 6 +- authentik/stages/authenticator_duo/api.py | 42 +- authentik/stages/authenticator_duo/tasks.py | 47 - authentik/stages/authenticator_email/tests.py | 286 +++--- .../tests/test_webauthn.py | 6 +- 
.../stages/authenticator_webauthn/apps.py | 13 + .../stages/authenticator_webauthn/settings.py | 17 - .../stages/authenticator_webauthn/tasks.py | 21 +- .../stages/authenticator_webauthn/tests.py | 6 +- authentik/stages/consent/tests.py | 2 +- authentik/stages/email/tasks.py | 136 ++- authentik/stages/email/tests/test_stage.py | 4 + authentik/stages/email/tests/test_tasks.py | 27 +- authentik/tasks/__init__.py | 0 authentik/tasks/api/__init__.py | 0 authentik/tasks/api/tasks.py | 138 +++ authentik/tasks/api/workers.py | 48 + authentik/tasks/apps.py | 21 + authentik/tasks/broker.py | 11 + authentik/tasks/forks.py | 44 + authentik/tasks/middleware.py | 210 ++++ authentik/tasks/migrations/0001_initial.py | 147 +++ authentik/tasks/migrations/__init__.py | 0 authentik/tasks/models.py | 169 ++++ authentik/tasks/schedules/__init__.py | 0 authentik/tasks/schedules/api.py | 133 +++ authentik/tasks/schedules/apps.py | 51 + authentik/tasks/schedules/common.py | 66 ++ .../schedules/migrations/0001_initial.py | 97 ++ .../tasks/schedules/migrations/__init__.py | 0 authentik/tasks/schedules/models.py | 73 ++ authentik/tasks/schedules/scheduler.py | 26 + authentik/tasks/schedules/signals.py | 18 + authentik/tasks/schedules/urls.py | 5 + authentik/tasks/setup.py | 14 + authentik/tasks/signals.py | 45 + authentik/tasks/tasks.py | 10 + authentik/tasks/test.py | 82 ++ authentik/tasks/tests/__init__.py | 0 authentik/tasks/tests/test_actors.py | 10 + authentik/tasks/tests/test_api.py | 26 + authentik/tasks/urls.py | 9 + authentik/tasks/worker.py | 0 authentik/tenants/migrations/0001_initial.py | 5 +- authentik/tenants/models.py | 6 +- authentik/tenants/scheduler.py | 22 - authentik/tenants/tests/test_settings.py | 67 -- authentik/tenants/tests/utils.py | 13 +- blueprints/schema.json | 123 ++- cmd/server/healthcheck.go | 80 +- lifecycle/ak | 18 +- lifecycle/gunicorn.conf.py | 7 +- manage.py | 4 +- packages/django-dramatiq-postgres/README.md | 0 .../django_dramatiq_postgres/__init__.py | 
0 .../django_dramatiq_postgres/apps.py | 45 + .../django_dramatiq_postgres/broker.py | 454 +++++++++ .../django_dramatiq_postgres/conf.py | 124 +++ .../django_dramatiq_postgres/forks.py | 18 + .../management/__init__.py | 0 .../management/commands/__init__.py | 0 .../management/commands/worker.py | 100 ++ .../django_dramatiq_postgres/middleware.py | 309 ++++++ .../django_dramatiq_postgres/models.py | 162 ++++ .../django_dramatiq_postgres/results.py | 43 + .../django_dramatiq_postgres/scheduler.py | 57 ++ .../django_dramatiq_postgres/setup.py | 3 + .../django-dramatiq-postgres/pyproject.toml | 50 + pyproject.toml | 19 +- schema.yml | 918 ++++++++++++------ scripts/generate_config.py | 6 + tests/e2e/test_source_ldap_samba.py | 13 +- tests/integration/test_outpost_docker.py | 2 +- tests/integration/test_outpost_kubernetes.py | 2 +- tests/integration/test_proxy_docker.py | 2 +- tests/integration/test_proxy_kubernetes.py | 2 +- uv.lock | 336 +++---- web/src/admin/Routes.ts | 4 +- .../admin/admin-overview/SystemTasksPage.ts | 85 ++ .../admin-overview/cards/WorkerStatusCard.ts | 4 +- .../admin-overview/charts/SyncStatusChart.ts | 39 +- web/src/admin/blueprints/BlueprintListPage.ts | 21 +- web/src/admin/events/RuleListPage.ts | 19 + web/src/admin/events/TransportListPage.ts | 27 + web/src/admin/outposts/OutpostListPage.ts | 38 +- .../outposts/ServiceConnectionListPage.ts | 43 + .../GoogleWorkspaceProviderViewPage.ts | 70 +- .../MicrosoftEntraProviderViewPage.ts | 73 +- .../providers/scim/SCIMProviderViewPage.ts | 206 ++-- .../providers/ssf/SSFProviderViewPage.ts | 11 + .../kerberos/KerberosSourceViewPage.ts | 94 +- .../admin/sources/ldap/LDAPSourceViewPage.ts | 68 +- .../admin/system-tasks/SystemTaskListPage.ts | 162 ---- web/src/elements/Label.ts | 2 + .../elements/sync/SyncStatusCard.stories.ts | 116 +-- web/src/elements/sync/SyncStatusCard.ts | 180 +--- web/src/elements/tasks/ScheduleForm.ts | 68 ++ web/src/elements/tasks/ScheduleList.ts | 179 ++++ 
web/src/elements/tasks/TaskList.ts | 203 ++++ web/src/elements/tasks/TaskStatus.ts | 66 ++ website/docs/core/architecture.md | 2 +- website/docs/developer-docs/index.md | 4 +- .../docs/developer-docs/setup/debugging.md | 6 +- .../setup/full-dev-environment.mdx | 18 +- .../configuration/configuration.mdx | 132 ++- website/docs/releases/2025/v2025.8.md | 93 ++ website/docs/sidebar.mjs | 2 + website/docs/sys-mgmt/background-tasks.md | 78 ++ website/docs/sys-mgmt/ops/monitoring.md | 8 +- website/docs/sys-mgmt/ops/worker.md | 50 + 242 files changed, 7227 insertions(+), 3897 deletions(-) delete mode 100644 authentik/admin/api/workers.py delete mode 100644 authentik/admin/settings.py delete mode 100644 authentik/admin/signals.py delete mode 100644 authentik/blueprints/settings.py create mode 100644 authentik/blueprints/tasks.py delete mode 100644 authentik/core/management/commands/bootstrap_tasks.py delete mode 100644 authentik/core/management/commands/worker.py delete mode 100644 authentik/crypto/settings.py delete mode 100644 authentik/enterprise/policies/unique_password/settings.py delete mode 100644 authentik/enterprise/providers/google_workspace/settings.py delete mode 100644 authentik/enterprise/providers/microsoft_entra/settings.py delete mode 100644 authentik/events/api/tasks.py create mode 100644 authentik/events/migrations/0011_alter_systemtask_options.py delete mode 100644 authentik/events/settings.py delete mode 100644 authentik/events/system_tasks.py delete mode 100644 authentik/events/tests/test_tasks.py create mode 100644 authentik/lib/sync/api.py delete mode 100644 authentik/outposts/settings.py create mode 100644 authentik/providers/proxy/signals.py create mode 100644 authentik/providers/proxy/tasks.py delete mode 100644 authentik/providers/scim/settings.py delete mode 100644 authentik/root/celery.py delete mode 100644 authentik/sources/kerberos/settings.py delete mode 100644 authentik/sources/ldap/settings.py delete mode 100644 
authentik/sources/oauth/settings.py delete mode 100644 authentik/sources/plex/settings.py delete mode 100644 authentik/stages/authenticator_duo/tasks.py delete mode 100644 authentik/stages/authenticator_webauthn/settings.py create mode 100644 authentik/tasks/__init__.py create mode 100644 authentik/tasks/api/__init__.py create mode 100644 authentik/tasks/api/tasks.py create mode 100644 authentik/tasks/api/workers.py create mode 100644 authentik/tasks/apps.py create mode 100644 authentik/tasks/broker.py create mode 100644 authentik/tasks/forks.py create mode 100644 authentik/tasks/middleware.py create mode 100644 authentik/tasks/migrations/0001_initial.py create mode 100644 authentik/tasks/migrations/__init__.py create mode 100644 authentik/tasks/models.py create mode 100644 authentik/tasks/schedules/__init__.py create mode 100644 authentik/tasks/schedules/api.py create mode 100644 authentik/tasks/schedules/apps.py create mode 100644 authentik/tasks/schedules/common.py create mode 100644 authentik/tasks/schedules/migrations/0001_initial.py create mode 100644 authentik/tasks/schedules/migrations/__init__.py create mode 100644 authentik/tasks/schedules/models.py create mode 100644 authentik/tasks/schedules/scheduler.py create mode 100644 authentik/tasks/schedules/signals.py create mode 100644 authentik/tasks/schedules/urls.py create mode 100644 authentik/tasks/setup.py create mode 100644 authentik/tasks/signals.py create mode 100644 authentik/tasks/tasks.py create mode 100644 authentik/tasks/test.py create mode 100644 authentik/tasks/tests/__init__.py create mode 100644 authentik/tasks/tests/test_actors.py create mode 100644 authentik/tasks/tests/test_api.py create mode 100644 authentik/tasks/urls.py create mode 100644 authentik/tasks/worker.py delete mode 100644 authentik/tenants/scheduler.py delete mode 100644 authentik/tenants/tests/test_settings.py create mode 100644 packages/django-dramatiq-postgres/README.md create mode 100644 
packages/django-dramatiq-postgres/django_dramatiq_postgres/__init__.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/forks.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/management/__init__.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/__init__.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/worker.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py create mode 100644 packages/django-dramatiq-postgres/django_dramatiq_postgres/setup.py create mode 100644 packages/django-dramatiq-postgres/pyproject.toml create mode 100644 web/src/admin/admin-overview/SystemTasksPage.ts delete mode 100644 web/src/admin/system-tasks/SystemTaskListPage.ts create mode 100644 web/src/elements/tasks/ScheduleForm.ts create mode 100644 web/src/elements/tasks/ScheduleList.ts create mode 100644 web/src/elements/tasks/TaskList.ts create mode 100644 web/src/elements/tasks/TaskStatus.ts create mode 100644 website/docs/releases/2025/v2025.8.md create mode 100644 website/docs/sys-mgmt/background-tasks.md create mode 100644 website/docs/sys-mgmt/ops/worker.md diff --git a/.gitignore b/.gitignore index d79d662b16..6062eb1bd1 100644 --- a/.gitignore +++ b/.gitignore @@ -100,9 +100,6 @@ ipython_config.py # pyenv .python-version -# celery beat schedule file 
-celerybeat-schedule - # SageMath parsed files *.sage.py @@ -166,8 +163,6 @@ dmypy.json # pyenv -# celery beat schedule file - # SageMath parsed files # Environments diff --git a/Dockerfile b/Dockerfile index ad3e51b324..2027a19ec1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -123,6 +123,7 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec" RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \ --mount=type=bind,target=uv.lock,src=uv.lock \ + --mount=type=bind,target=packages,src=packages \ --mount=type=cache,target=/root/.cache/uv \ uv sync --frozen --no-install-project --no-dev @@ -168,6 +169,7 @@ COPY ./blueprints /blueprints COPY ./lifecycle/ /lifecycle COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf COPY --from=go-builder /go/authentik /bin/authentik +COPY ./packages/ /ak-root/packages COPY --from=python-deps /ak-root/.venv /ak-root/.venv COPY --from=node-builder /work/web/dist/ /web/dist/ COPY --from=node-builder /work/web/authentik/ /web/authentik/ diff --git a/Makefile b/Makefile index 0495f541ec..f78e2bc496 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ PWD = $(shell pwd) UID = $(shell id -u) GID = $(shell id -g) NPM_VERSION = $(shell python -m scripts.generate_semver) -PY_SOURCES = authentik tests scripts lifecycle .github +PY_SOURCES = authentik packages tests scripts lifecycle .github DOCKER_IMAGE ?= "authentik:test" GEN_API_TS = gen-ts-api @@ -59,9 +59,12 @@ i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that requir aws-cfn: cd lifecycle/aws && npm run aws-cfn -run: ## Run the main authentik server process +run-server: ## Run the main authentik server process uv run ak server +run-worker: ## Run the main authentik worker process + uv run ak worker + core-i18n-extract: uv run ak makemessages \ --add-location file \ diff --git a/authentik/admin/api/version.py b/authentik/admin/api/version.py index 48ec8bd0f5..f7eecdd53e 100644 --- a/authentik/admin/api/version.py +++ 
b/authentik/admin/api/version.py @@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer): return __version__ version_in_cache = cache.get(VERSION_CACHE_KEY) if not version_in_cache: # pragma: no cover - update_latest_version.delay() + update_latest_version.send() return __version__ return version_in_cache diff --git a/authentik/admin/api/workers.py b/authentik/admin/api/workers.py deleted file mode 100644 index b7c2f08b2c..0000000000 --- a/authentik/admin/api/workers.py +++ /dev/null @@ -1,57 +0,0 @@ -"""authentik administration overview""" - -from socket import gethostname - -from django.conf import settings -from drf_spectacular.utils import extend_schema, inline_serializer -from packaging.version import parse -from rest_framework.fields import BooleanField, CharField -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.views import APIView - -from authentik import get_full_version -from authentik.rbac.permissions import HasPermission -from authentik.root.celery import CELERY_APP - - -class WorkerView(APIView): - """Get currently connected worker count.""" - - permission_classes = [HasPermission("authentik_rbac.view_system_info")] - - @extend_schema( - responses=inline_serializer( - "Worker", - fields={ - "worker_id": CharField(), - "version": CharField(), - "version_matching": BooleanField(), - }, - many=True, - ) - ) - def get(self, request: Request) -> Response: - """Get currently connected worker count.""" - raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5) - our_version = parse(get_full_version()) - response = [] - for worker in raw: - key = list(worker.keys())[0] - version = worker[key].get("version") - version_matching = False - if version: - version_matching = parse(version) == our_version - response.append( - {"worker_id": key, "version": version, "version_matching": version_matching} - ) - # In debug we run with `task_always_eager`, so tasks are ran on the main process - if 
settings.DEBUG: # pragma: no cover - response.append( - { - "worker_id": f"authentik-debug@{gethostname()}", - "version": get_full_version(), - "version_matching": True, - } - ) - return Response(response) diff --git a/authentik/admin/apps.py b/authentik/admin/apps.py index 31e55797d9..963cee9149 100644 --- a/authentik/admin/apps.py +++ b/authentik/admin/apps.py @@ -3,6 +3,9 @@ from prometheus_client import Info from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.config import CONFIG +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec PROM_INFO = Info("authentik_version", "Currently running authentik version") @@ -30,3 +33,15 @@ class AuthentikAdminConfig(ManagedAppConfig): notification_version = notification.event.context["new_version"] if LOCAL_VERSION >= parse(notification_version): notification.delete() + + @property + def global_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.admin.tasks import update_latest_version + + return [ + ScheduleSpec( + actor=update_latest_version, + crontab=f"{fqdn_rand('admin_latest_version')} * * * *", + paused=CONFIG.get_bool("disable_update_check"), + ), + ] diff --git a/authentik/admin/settings.py b/authentik/admin/settings.py deleted file mode 100644 index 4db9a57a97..0000000000 --- a/authentik/admin/settings.py +++ /dev/null @@ -1,15 +0,0 @@ -"""authentik admin settings""" - -from celery.schedules import crontab -from django_tenants.utils import get_public_schema_name - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "admin_latest_version": { - "task": "authentik.admin.tasks.update_latest_version", - "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"), - "tenant_schemas": [get_public_schema_name()], - "options": {"queue": "authentik_scheduled"}, - } -} diff --git a/authentik/admin/signals.py b/authentik/admin/signals.py deleted file mode 100644 index d6856b0fa9..0000000000 --- 
a/authentik/admin/signals.py +++ /dev/null @@ -1,35 +0,0 @@ -"""admin signals""" - -from django.dispatch import receiver -from packaging.version import parse -from prometheus_client import Gauge - -from authentik import get_full_version -from authentik.root.celery import CELERY_APP -from authentik.root.monitoring import monitoring_set - -GAUGE_WORKERS = Gauge( - "authentik_admin_workers", - "Currently connected workers, their versions and if they are the same version as authentik", - ["version", "version_matched"], -) - - -_version = parse(get_full_version()) - - -@receiver(monitoring_set) -def monitoring_set_workers(sender, **kwargs): - """Set worker gauge""" - raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5) - worker_version_count = {} - for worker in raw: - key = list(worker.keys())[0] - version = worker[key].get("version") - version_matching = False - if version: - version_matching = parse(version) == _version - worker_version_count.setdefault(version, {"count": 0, "matching": version_matching}) - worker_version_count[version]["count"] += 1 - for version, stats in worker_version_count.items(): - GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"]) diff --git a/authentik/admin/tasks.py b/authentik/admin/tasks.py index 30b266ba65..68fac20234 100644 --- a/authentik/admin/tasks.py +++ b/authentik/admin/tasks.py @@ -2,6 +2,8 @@ from django.core.cache import cache from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq import actor from packaging.version import parse from requests import RequestException from structlog.stdlib import get_logger @@ -9,10 +11,9 @@ from structlog.stdlib import get_logger from authentik import __version__, get_build_hash from authentik.admin.apps import PROM_INFO from authentik.events.models import Event, EventAction -from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task from authentik.lib.config import 
CONFIG from authentik.lib.utils.http import get_http_session -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = get_logger() VERSION_NULL = "0.0.0" @@ -32,13 +33,12 @@ def _set_prom_info(): ) -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def update_latest_version(self: SystemTask): - """Update latest version info""" +@actor(description=_("Update latest version info.")) +def update_latest_version(): + self: Task = CurrentTask.get_task() if CONFIG.get_bool("disable_update_check"): cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) - self.set_status(TaskStatus.WARNING, "Version check disabled.") + self.info("Version check disabled.") return try: response = get_http_session().get( @@ -48,7 +48,7 @@ def update_latest_version(self: SystemTask): data = response.json() upstream_version = data.get("stable", {}).get("version") cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) - self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version") + self.info("Successfully updated latest Version") _set_prom_info() # Check if upstream version is newer than what we're running, # and if no event exists yet, create one. 
@@ -71,7 +71,7 @@ def update_latest_version(self: SystemTask): ).save() except (RequestException, IndexError) as exc: cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) - self.set_error(exc) + raise exc _set_prom_info() diff --git a/authentik/admin/tests/test_api.py b/authentik/admin/tests/test_api.py index 1268812e7d..1f38ee8d15 100644 --- a/authentik/admin/tests/test_api.py +++ b/authentik/admin/tests/test_api.py @@ -29,13 +29,6 @@ class TestAdminAPI(TestCase): body = loads(response.content) self.assertEqual(body["version_current"], __version__) - def test_workers(self): - """Test Workers API""" - response = self.client.get(reverse("authentik_api:admin_workers")) - self.assertEqual(response.status_code, 200) - body = loads(response.content) - self.assertEqual(len(body), 0) - def test_apps(self): """Test apps API""" response = self.client.get(reverse("authentik_api:apps-list")) diff --git a/authentik/admin/tests/test_tasks.py b/authentik/admin/tests/test_tasks.py index 22926b5948..d81c79175b 100644 --- a/authentik/admin/tests/test_tasks.py +++ b/authentik/admin/tests/test_tasks.py @@ -30,7 +30,7 @@ class TestAdminTasks(TestCase): """Test Update checker with valid response""" with Mocker() as mocker, CONFIG.patch("disable_update_check", False): mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID) - update_latest_version.delay().get() + update_latest_version.send() self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") self.assertTrue( Event.objects.filter( @@ -40,7 +40,7 @@ class TestAdminTasks(TestCase): ).exists() ) # test that a consecutive check doesn't create a duplicate event - update_latest_version.delay().get() + update_latest_version.send() self.assertEqual( len( Event.objects.filter( @@ -56,7 +56,7 @@ class TestAdminTasks(TestCase): """Test Update checker with invalid response""" with Mocker() as mocker: mocker.get("https://version.goauthentik.io/version.json", status_code=400) - 
update_latest_version.delay().get() + update_latest_version.send() self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") self.assertFalse( Event.objects.filter( @@ -67,14 +67,15 @@ class TestAdminTasks(TestCase): def test_version_disabled(self): """Test Update checker while its disabled""" with CONFIG.patch("disable_update_check", True): - update_latest_version.delay().get() + update_latest_version.send() self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") def test_clear_update_notifications(self): """Test clear of previous notification""" admin_config = apps.get_app_config("authentik_admin") Event.objects.create( - action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"} + action=EventAction.UPDATE_AVAILABLE, + context={"new_version": "99999999.9999999.9999999"}, ) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={}) diff --git a/authentik/admin/urls.py b/authentik/admin/urls.py index 0dd6fc02f2..9a1dd04e11 100644 --- a/authentik/admin/urls.py +++ b/authentik/admin/urls.py @@ -6,13 +6,11 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet from authentik.admin.api.system import SystemView from authentik.admin.api.version import VersionView from authentik.admin.api.version_history import VersionHistoryViewSet -from authentik.admin.api.workers import WorkerView api_urlpatterns = [ ("admin/apps", AppsViewSet, "apps"), ("admin/models", ModelViewSet, "models"), path("admin/version/", VersionView.as_view(), name="admin_version"), ("admin/version/history", VersionHistoryViewSet, "version_history"), - path("admin/workers/", WorkerView.as_view(), name="admin_workers"), path("admin/system/", SystemView.as_view(), name="admin_system"), ] diff --git a/authentik/blueprints/api.py b/authentik/blueprints/api.py index 686f194061..33578e2043 100644 --- a/authentik/blueprints/api.py +++ b/authentik/blueprints/api.py @@ -39,7 
+39,7 @@ class BlueprintInstanceSerializer(ModelSerializer): """Ensure the path (if set) specified is retrievable""" if path == "" or path.startswith(OCI_PREFIX): return path - files: list[dict] = blueprints_find_dict.delay().get() + files: list[dict] = blueprints_find_dict.send().get_result(block=True) if path not in [file["path"] for file in files]: raise ValidationError(_("Blueprint file does not exist")) return path @@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): @action(detail=False, pagination_class=None, filter_backends=[]) def available(self, request: Request) -> Response: """Get blueprints""" - files: list[dict] = blueprints_find_dict.delay().get() + files: list[dict] = blueprints_find_dict.send().get_result(block=True) return Response(files) @permission_required("authentik_blueprints.view_blueprintinstance") @@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): def apply(self, request: Request, *args, **kwargs) -> Response: """Apply a blueprint""" blueprint = self.get_object() - apply_blueprint.delay(str(blueprint.pk)).get() + apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint) return self.retrieve(request, *args, **kwargs) diff --git a/authentik/blueprints/apps.py b/authentik/blueprints/apps.py index eeef482764..9b56558bad 100644 --- a/authentik/blueprints/apps.py +++ b/authentik/blueprints/apps.py @@ -6,9 +6,12 @@ from inspect import ismethod from django.apps import AppConfig from django.db import DatabaseError, InternalError, ProgrammingError +from dramatiq.broker import get_broker from structlog.stdlib import BoundLogger, get_logger +from authentik.lib.utils.time import fqdn_rand from authentik.root.signals import startup +from authentik.tasks.schedules.common import ScheduleSpec class ManagedAppConfig(AppConfig): @@ -34,7 +37,7 @@ class ManagedAppConfig(AppConfig): def import_related(self): """Automatically import related modules which rely on just being imported - to 
register themselves (mainly django signals and celery tasks)""" + to register themselves (mainly django signals and tasks)""" def import_relative(rel_module: str): try: @@ -80,6 +83,16 @@ class ManagedAppConfig(AppConfig): func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY return func + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + """Get a list of schedule specs that must exist in each tenant""" + return [] + + @property + def global_schedule_specs(self) -> list[ScheduleSpec]: + """Get a list of schedule specs that must exist in the default tenant""" + return [] + def _reconcile_tenant(self) -> None: """reconcile ourselves for tenanted methods""" from authentik.tenants.models import Tenant @@ -100,8 +113,12 @@ class ManagedAppConfig(AppConfig): """ from django_tenants.utils import get_public_schema_name, schema_context - with schema_context(get_public_schema_name()): - self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) + try: + with schema_context(get_public_schema_name()): + self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) + except (DatabaseError, ProgrammingError, InternalError) as exc: + self.logger.debug("Failed to access database to run reconcile", exc=exc) + return class AuthentikBlueprintsConfig(ManagedAppConfig): @@ -112,19 +129,29 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): verbose_name = "authentik Blueprints" default = True - @ManagedAppConfig.reconcile_global - def load_blueprints_v1_tasks(self): - """Load v1 tasks""" - self.import_module("authentik.blueprints.v1.tasks") - - @ManagedAppConfig.reconcile_tenant - def blueprints_discovery(self): - """Run blueprint discovery""" - from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints - - blueprints_discovery.delay() - clear_failed_blueprints.delay() - def import_models(self): super().import_models() self.import_module("authentik.blueprints.v1.meta.apply_blueprint") + + @ManagedAppConfig.reconcile_global + def 
tasks_middlewares(self): + from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware + + get_broker().add_middleware(BlueprintWatcherMiddleware()) + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints + + return [ + ScheduleSpec( + actor=blueprints_discovery, + crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *", + send_on_startup=True, + ), + ScheduleSpec( + actor=clear_failed_blueprints, + crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *", + send_on_startup=True, + ), + ] diff --git a/authentik/blueprints/models.py b/authentik/blueprints/models.py index a3abcba59f..7ee87a2d8b 100644 --- a/authentik/blueprints/models.py +++ b/authentik/blueprints/models.py @@ -3,6 +3,7 @@ from pathlib import Path from uuid import uuid4 +from django.contrib.contenttypes.fields import GenericRelation from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.translation import gettext_lazy as _ @@ -71,6 +72,13 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): enabled = models.BooleanField(default=True) managed_models = ArrayField(models.TextField(), default=list) + # Manual link to tasks instead of using TasksModel because of loop imports + tasks = GenericRelation( + "authentik_tasks.Task", + content_type_field="rel_obj_content_type", + object_id_field="rel_obj_id", + ) + class Meta: verbose_name = _("Blueprint Instance") verbose_name_plural = _("Blueprint Instances") diff --git a/authentik/blueprints/settings.py b/authentik/blueprints/settings.py deleted file mode 100644 index da0029d678..0000000000 --- a/authentik/blueprints/settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""blueprint Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "blueprints_v1_discover": { - "task": 
"authentik.blueprints.v1.tasks.blueprints_discovery", - "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, - "blueprints_v1_cleanup": { - "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints", - "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/blueprints/tasks.py b/authentik/blueprints/tasks.py new file mode 100644 index 0000000000..6a90f77db5 --- /dev/null +++ b/authentik/blueprints/tasks.py @@ -0,0 +1,2 @@ +# Import all v1 tasks for auto task discovery +from authentik.blueprints.v1.tasks import * # noqa: F403 diff --git a/authentik/blueprints/tests/test_v1_tasks.py b/authentik/blueprints/tests/test_v1_tasks.py index b1d201419d..bf393e5a1c 100644 --- a/authentik/blueprints/tests/test_v1_tasks.py +++ b/authentik/blueprints/tests/test_v1_tasks.py @@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): file.seek(0) file_hash = sha512(file.read().encode()).hexdigest() file.flush() - blueprints_discovery() + blueprints_discovery.send() instance = BlueprintInstance.objects.filter(name=blueprint_id).first() self.assertEqual(instance.last_applied_hash, file_hash) self.assertEqual( @@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): ) ) file.flush() - blueprints_discovery() + blueprints_discovery.send() blueprint = BlueprintInstance.objects.filter(name="foo").first() self.assertEqual( blueprint.last_applied_hash, @@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): ) ) file.flush() - blueprints_discovery() + blueprints_discovery.send() blueprint.refresh_from_db() self.assertEqual( blueprint.last_applied_hash, diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 56d6ab47a0..d4e3738bab 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -57,7 +57,6 @@ from 
authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( EndpointDeviceConnection, ) from authentik.events.logs import LogEvent, capture_logs -from authentik.events.models import SystemTask from authentik.events.utils import cleanse_dict from authentik.flows.models import FlowToken, Stage from authentik.lib.models import SerializerModel @@ -77,6 +76,7 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser from authentik.rbac.models import Role from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType +from authentik.tasks.models import Task from authentik.tenants.models import Tenant # Context set when the serializer is created in a blueprint context @@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]: SCIMProviderGroup, SCIMProviderUser, Tenant, - SystemTask, + Task, ConnectionToken, AuthorizationCode, AccessToken, diff --git a/authentik/blueprints/v1/meta/apply_blueprint.py b/authentik/blueprints/v1/meta/apply_blueprint.py index abd593c045..de518e4838 100644 --- a/authentik/blueprints/v1/meta/apply_blueprint.py +++ b/authentik/blueprints/v1/meta/apply_blueprint.py @@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): return MetaResult() LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) - apply_blueprint(str(self.blueprint_instance.pk)) + apply_blueprint(self.blueprint_instance.pk) return MetaResult() diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 73e712b8f8..5da8a5090d 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -4,12 +4,17 @@ from dataclasses import asdict, dataclass, field from hashlib import sha512 from pathlib import Path from sys import platform +from uuid import UUID from dacite.core import from_dict +from django.conf import settings from django.db import DatabaseError, 
InternalError, ProgrammingError from django.utils.text import slugify from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound +from dramatiq.actor import actor +from dramatiq.middleware import Middleware from structlog.stdlib import get_logger from watchdog.events import ( FileCreatedEvent, @@ -31,15 +36,13 @@ from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.events.logs import capture_logs -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask, prefill_task from authentik.events.utils import sanitize_dict from authentik.lib.config import CONFIG -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task +from authentik.tasks.schedules.models import Schedule from authentik.tenants.models import Tenant LOGGER = get_logger() -_file_watcher_started = False @dataclass @@ -53,22 +56,21 @@ class BlueprintFile: meta: BlueprintMetadata | None = field(default=None) -def start_blueprint_watcher(): - """Start blueprint watcher, if it's not running already.""" - # This function might be called twice since it's called on celery startup +class BlueprintWatcherMiddleware(Middleware): + def start_blueprint_watcher(self): + """Start blueprint watcher""" + observer = Observer() + kwargs = {} + if platform.startswith("linux"): + kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent) + observer.schedule( + BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs + ) + observer.start() - global _file_watcher_started # noqa: PLW0603 - if _file_watcher_started: - return - observer = Observer() - kwargs = {} - if platform.startswith("linux"): - kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent) - observer.schedule( - 
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs - ) - observer.start() - _file_watcher_started = True + def after_worker_boot(self, broker, worker): + if not settings.TEST: + self.start_blueprint_watcher() class BlueprintEventHandler(FileSystemEventHandler): @@ -92,7 +94,7 @@ class BlueprintEventHandler(FileSystemEventHandler): LOGGER.debug("new blueprint file created, starting discovery") for tenant in Tenant.objects.filter(ready=True): with tenant: - blueprints_discovery.delay() + Schedule.dispatch_by_actor(blueprints_discovery) def on_modified(self, event: FileSystemEvent): """Process file modification""" @@ -103,14 +105,14 @@ class BlueprintEventHandler(FileSystemEventHandler): with tenant: for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): LOGGER.debug("modified blueprint file, starting apply", instance=instance) - apply_blueprint.delay(instance.pk.hex) + apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) -@CELERY_APP.task( +@actor( + description=_("Find blueprints as `blueprints_find` does, but return a safe dict."), throws=(DatabaseError, ProgrammingError, InternalError), ) def blueprints_find_dict(): - """Find blueprints as `blueprints_find` does, but return a safe dict""" blueprints = [] for blueprint in blueprints_find(): blueprints.append(sanitize_dict(asdict(blueprint))) @@ -146,21 +148,19 @@ def blueprints_find() -> list[BlueprintFile]: return blueprints -@CELERY_APP.task( - throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True +@actor( + description=_("Find blueprints and check if they need to be created in the database."), + throws=(DatabaseError, ProgrammingError, InternalError), ) -@prefill_task -def blueprints_discovery(self: SystemTask, path: str | None = None): - """Find blueprints and check if they need to be created in the database""" +def blueprints_discovery(path: str | None = None): + self: Task = CurrentTask.get_task() count = 0 
for blueprint in blueprints_find(): if path and blueprint.path != path: continue check_blueprint_v1_file(blueprint) count += 1 - self.set_status( - TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count)) - ) + self.info(f"Successfully imported {count} files.") def check_blueprint_v1_file(blueprint: BlueprintFile): @@ -187,22 +187,26 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): ) if instance.last_applied_hash != blueprint.hash: LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path) - apply_blueprint.delay(str(instance.pk)) + apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) -@CELERY_APP.task( - bind=True, - base=SystemTask, -) -def apply_blueprint(self: SystemTask, instance_pk: str): - """Apply single blueprint""" - self.save_on_success = False +@actor(description=_("Apply single blueprint.")) +def apply_blueprint(instance_pk: UUID): + try: + self: Task = CurrentTask.get_task() + except CurrentTaskNotFound: + self = Task() + self.set_uid(str(instance_pk)) instance: BlueprintInstance | None = None try: instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() - if not instance or not instance.enabled: + if not instance: + self.warning(f"Could not find blueprint {instance_pk}, skipping") return self.set_uid(slugify(instance.name)) + if not instance.enabled: + self.info(f"Blueprint {instance.name} is disabled, skipping") + return blueprint_content = instance.retrieve() file_hash = sha512(blueprint_content.encode()).hexdigest() importer = Importer.from_string(blueprint_content, instance.context) @@ -212,19 +216,18 @@ def apply_blueprint(self: SystemTask, instance_pk: str): if not valid: instance.status = BlueprintInstanceStatus.ERROR instance.save() - self.set_status(TaskStatus.ERROR, *logs) + self.logs(logs) return with capture_logs() as logs: applied = importer.apply() if not applied: instance.status = BlueprintInstanceStatus.ERROR 
instance.save() - self.set_status(TaskStatus.ERROR, *logs) + self.logs(logs) return instance.status = BlueprintInstanceStatus.SUCCESSFUL instance.last_applied_hash = file_hash instance.last_applied = now() - self.set_status(TaskStatus.SUCCESSFUL) except ( OSError, DatabaseError, @@ -235,15 +238,14 @@ def apply_blueprint(self: SystemTask, instance_pk: str): ) as exc: if instance: instance.status = BlueprintInstanceStatus.ERROR - self.set_error(exc) + self.error(exc) finally: if instance: instance.save() -@CELERY_APP.task() +@actor(description=_("Remove blueprints which couldn't be fetched.")) def clear_failed_blueprints(): - """Remove blueprints which couldn't be fetched""" # Exclude OCI blueprints as those might be temporarily unavailable for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX): try: diff --git a/authentik/brands/apps.py b/authentik/brands/apps.py index 8daa0525a9..f1d37473b3 100644 --- a/authentik/brands/apps.py +++ b/authentik/brands/apps.py @@ -9,6 +9,7 @@ class AuthentikBrandsConfig(ManagedAppConfig): name = "authentik.brands" label = "authentik_brands" verbose_name = "authentik Brands" + default = True mountpoints = { "authentik.brands.urls_root": "", } diff --git a/authentik/core/apps.py b/authentik/core/apps.py index 87f7992682..392980a035 100644 --- a/authentik/core/apps.py +++ b/authentik/core/apps.py @@ -1,8 +1,7 @@ """authentik core app config""" -from django.conf import settings - from authentik.blueprints.apps import ManagedAppConfig +from authentik.tasks.schedules.common import ScheduleSpec class AuthentikCoreConfig(ManagedAppConfig): @@ -14,14 +13,6 @@ class AuthentikCoreConfig(ManagedAppConfig): mountpoint = "" default = True - @ManagedAppConfig.reconcile_global - def debug_worker_hook(self): - """Dispatch startup tasks inline when debugging""" - if settings.DEBUG: - from authentik.root.celery import worker_ready_hook - - worker_ready_hook() - @ManagedAppConfig.reconcile_tenant def source_inbuilt(self): 
"""Reconcile inbuilt source""" @@ -34,3 +25,18 @@ class AuthentikCoreConfig(ManagedAppConfig): }, managed=Source.MANAGED_INBUILT, ) + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.core.tasks import clean_expired_models, clean_temporary_users + + return [ + ScheduleSpec( + actor=clean_expired_models, + crontab="2-59/5 * * * *", + ), + ScheduleSpec( + actor=clean_temporary_users, + crontab="9-59/5 * * * *", + ), + ] diff --git a/authentik/core/management/commands/bootstrap_tasks.py b/authentik/core/management/commands/bootstrap_tasks.py deleted file mode 100644 index 1038796b9a..0000000000 --- a/authentik/core/management/commands/bootstrap_tasks.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run bootstrap tasks""" - -from django.core.management.base import BaseCommand -from django_tenants.utils import get_public_schema_name - -from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant -from authentik.tenants.models import Tenant - - -class Command(BaseCommand): - """Run bootstrap tasks to ensure certain objects are created""" - - def handle(self, **options): - for task in _get_startup_tasks_default_tenant(): - with Tenant.objects.get(schema_name=get_public_schema_name()): - task() - - for task in _get_startup_tasks_all_tenants(): - for tenant in Tenant.objects.filter(ready=True): - with tenant: - task() diff --git a/authentik/core/management/commands/worker.py b/authentik/core/management/commands/worker.py deleted file mode 100644 index 8b3ed9346c..0000000000 --- a/authentik/core/management/commands/worker.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Run worker""" - -from sys import exit as sysexit -from tempfile import tempdir - -from celery.apps.worker import Worker -from django.core.management.base import BaseCommand -from django.db import close_old_connections -from structlog.stdlib import get_logger - -from authentik.lib.config import CONFIG -from authentik.lib.debug import start_debug_server -from 
authentik.root.celery import CELERY_APP - -LOGGER = get_logger() - - -class Command(BaseCommand): - """Run worker""" - - def add_arguments(self, parser): - parser.add_argument( - "-b", - "--beat", - action="store_false", - help="When set, this worker will _not_ run Beat (scheduled) tasks", - ) - - def handle(self, **options): - LOGGER.debug("Celery options", **options) - close_old_connections() - start_debug_server() - worker: Worker = CELERY_APP.Worker( - no_color=False, - quiet=True, - optimization="fair", - autoscale=(CONFIG.get_int("worker.concurrency"), 1), - task_events=True, - beat=options.get("beat", True), - schedule_filename=f"{tempdir}/celerybeat-schedule", - queues=["authentik", "authentik_scheduled", "authentik_events"], - ) - for task in CELERY_APP.tasks: - LOGGER.debug("Registered task", task=task) - - worker.start() - sysexit(worker.exitcode) diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 80ca278960..1f398dc8b0 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -3,6 +3,9 @@ from datetime import datetime, timedelta from django.utils.timezone import now +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from structlog.stdlib import get_logger from authentik.core.models import ( @@ -11,17 +14,14 @@ from authentik.core.models import ( ExpiringModel, User, ) -from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def clean_expired_models(self: SystemTask): - """Remove expired objects""" - messages = [] +@actor(description=_("Remove expired objects.")) +def clean_expired_models(): + self: Task = CurrentTask.get_task() for cls in ExpiringModel.__subclasses__(): cls: ExpiringModel objects = ( @@ -31,16 +31,13 @@ def 
clean_expired_models(self: SystemTask): for obj in objects: obj.expire_action() LOGGER.debug("Expired models", model=cls, amount=amount) - messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") - self.set_status(TaskStatus.SUCCESSFUL, *messages) + self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def clean_temporary_users(self: SystemTask): - """Remove temporary users created by SAML Sources""" +@actor(description=_("Remove temporary users created by SAML Sources.")) +def clean_temporary_users(): + self: Task = CurrentTask.get_task() _now = datetime.now() - messages = [] deleted_users = 0 for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): @@ -52,5 +49,4 @@ def clean_temporary_users(self: SystemTask): LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) user.delete() deleted_users += 1 - messages.append(f"Successfully deleted {deleted_users} users.") - self.set_status(TaskStatus.SUCCESSFUL, *messages) + self.info(f"Successfully deleted {deleted_users} users.") diff --git a/authentik/core/tests/test_tasks.py b/authentik/core/tests/test_tasks.py index 92e83066e4..89def48506 100644 --- a/authentik/core/tests/test_tasks.py +++ b/authentik/core/tests/test_tasks.py @@ -36,7 +36,7 @@ class TestTasks(APITestCase): expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API ) key = token.key - clean_expired_models.delay().get() + clean_expired_models.send() token.refresh_from_db() self.assertNotEqual(key, token.key) @@ -50,5 +50,5 @@ class TestTasks(APITestCase): USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), }, ) - clean_temporary_users.delay().get() + clean_temporary_users.send() self.assertFalse(User.objects.filter(username=username)) diff --git a/authentik/crypto/apps.py b/authentik/crypto/apps.py index cdb01b3a1b..af96006d06 100644 --- 
a/authentik/crypto/apps.py +++ b/authentik/crypto/apps.py @@ -4,6 +4,8 @@ from datetime import UTC, datetime from authentik.blueprints.apps import ManagedAppConfig from authentik.lib.generators import generate_id +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" @@ -67,3 +69,14 @@ class AuthentikCryptoConfig(ManagedAppConfig): "key_data": builder.private_key, }, ) + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.crypto.tasks import certificate_discovery + + return [ + ScheduleSpec( + actor=certificate_discovery, + crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *", + ), + ] diff --git a/authentik/crypto/settings.py b/authentik/crypto/settings.py deleted file mode 100644 index 8316e9b84d..0000000000 --- a/authentik/crypto/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Crypto task Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "crypto_certificate_discovery": { - "task": "authentik.crypto.tasks.certificate_discovery", - "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/crypto/tasks.py b/authentik/crypto/tasks.py index bce3f998c7..5012aac1a5 100644 --- a/authentik/crypto/tasks.py +++ b/authentik/crypto/tasks.py @@ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.x509.base import load_pem_x509_certificate from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from structlog.stdlib import get_logger from authentik.crypto.models import CertificateKeyPair -from authentik.events.models import TaskStatus -from 
authentik.events.system_tasks import SystemTask, prefill_task from authentik.lib.config import CONFIG -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = get_logger() @@ -36,10 +36,9 @@ def ensure_certificate_valid(body: str): return body -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def certificate_discovery(self: SystemTask): - """Discover, import and update certificates from the filesystem""" +@actor(description=_("Discover, import and update certificates from the filesystem.")) +def certificate_discovery(): + self: Task = CurrentTask.get_task() certs = {} private_keys = {} discovered = 0 @@ -84,6 +83,4 @@ def certificate_discovery(self: SystemTask): dirty = True if dirty: cert.save() - self.set_status( - TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered)) - ) + self.info(f"Successfully imported {discovered} files.") diff --git a/authentik/crypto/tests.py b/authentik/crypto/tests.py index 0e3c886d11..8b3b24a68e 100644 --- a/authentik/crypto/tests.py +++ b/authentik/crypto/tests.py @@ -338,7 +338,7 @@ class TestCrypto(APITestCase): with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: _key.write(builder.private_key) with CONFIG.patch("cert_discovery_dir", temp_dir): - certificate_discovery() + certificate_discovery.send() keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( managed=MANAGED_DISCOVERED % "foo" ).first() diff --git a/authentik/enterprise/apps.py b/authentik/enterprise/apps.py index 31d6eba9a1..94afc27bfb 100644 --- a/authentik/enterprise/apps.py +++ b/authentik/enterprise/apps.py @@ -3,6 +3,8 @@ from django.conf import settings from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec class EnterpriseConfig(ManagedAppConfig): @@ -26,3 +28,14 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): from 
authentik.enterprise.license import LicenseKey return LicenseKey.cached_summary().status.is_valid + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.enterprise.tasks import enterprise_update_usage + + return [ + ScheduleSpec( + actor=enterprise_update_usage, + crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *", + ), + ] diff --git a/authentik/enterprise/policies/unique_password/apps.py b/authentik/enterprise/policies/unique_password/apps.py index 6ed4734cae..d58e58b9a9 100644 --- a/authentik/enterprise/policies/unique_password/apps.py +++ b/authentik/enterprise/policies/unique_password/apps.py @@ -1,6 +1,8 @@ """authentik Unique Password policy app config""" from authentik.enterprise.apps import EnterpriseConfig +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): @@ -8,3 +10,21 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): label = "authentik_policies_unique_password" verbose_name = "authentik Enterprise.Policies.Unique Password" default = True + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.enterprise.policies.unique_password.tasks import ( + check_and_purge_password_history, + trim_password_histories, + ) + + return [ + ScheduleSpec( + actor=trim_password_histories, + crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *", + ), + ScheduleSpec( + actor=check_and_purge_password_history, + crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *", + ), + ] diff --git a/authentik/enterprise/policies/unique_password/settings.py b/authentik/enterprise/policies/unique_password/settings.py deleted file mode 100644 index 2d83afbeb1..0000000000 --- a/authentik/enterprise/policies/unique_password/settings.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Unique Password Policy settings""" - -from celery.schedules import 
crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "policies_unique_password_trim_history": { - "task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories", - "schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"), - "options": {"queue": "authentik_scheduled"}, - }, - "policies_unique_password_check_purge": { - "task": ( - "authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history" - ), - "schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/enterprise/policies/unique_password/tasks.py b/authentik/enterprise/policies/unique_password/tasks.py index d871f3cf31..43fcd40541 100644 --- a/authentik/enterprise/policies/unique_password/tasks.py +++ b/authentik/enterprise/policies/unique_password/tasks.py @@ -1,35 +1,37 @@ from django.db.models.aggregates import Count +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from structlog import get_logger from authentik.enterprise.policies.unique_password.models import ( UniquePasswordPolicy, UserPasswordHistory, ) -from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def check_and_purge_password_history(self: SystemTask): - """Check if any UniquePasswordPolicy exists, and if not, purge the password history table. - This is run on a schedule instead of being triggered by policy binding deletion. - """ +@actor( + description=_( + "Check if any UniquePasswordPolicy exists, and if not, purge the password history table." 
+ ) +) +def check_and_purge_password_history(): + self: Task = CurrentTask.get_task() + if not UniquePasswordPolicy.objects.exists(): UserPasswordHistory.objects.all().delete() LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") - self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory") + self.info("Successfully purged UserPasswordHistory") return - self.set_status( - TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists" - ) + self.info("Not purging password histories, a unique password policy exists") -@CELERY_APP.task(bind=True, base=SystemTask) -def trim_password_histories(self: SystemTask): +@actor(description=_("Remove user password history that are too old.")) +def trim_password_histories(): """Removes rows from UserPasswordHistory older than the `n` most recent entries. @@ -37,6 +39,8 @@ def trim_password_histories(self: SystemTask): UniquePasswordPolicy policies. """ + self: Task = CurrentTask.get_task() + # No policy, we'll let the cleanup above do its thing if not UniquePasswordPolicy.objects.exists(): return @@ -63,4 +67,4 @@ def trim_password_histories(self: SystemTask): num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() LOGGER.debug("Deleted stale password history records", count=num_deleted) - self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records") + self.info(f"Delete {num_deleted} stale password history records") diff --git a/authentik/enterprise/policies/unique_password/tests/test_tasks.py b/authentik/enterprise/policies/unique_password/tests/test_tasks.py index 3e46c67b9d..3ecd5e9b45 100644 --- a/authentik/enterprise/policies/unique_password/tests/test_tasks.py +++ b/authentik/enterprise/policies/unique_password/tests/test_tasks.py @@ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): self.assertTrue(UserPasswordHistory.objects.exists()) # Run the task - should purge since no 
policy is in use - check_and_purge_password_history() + check_and_purge_password_history.send() # Verify the table is empty self.assertFalse(UserPasswordHistory.objects.exists()) @@ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): self.assertTrue(UserPasswordHistory.objects.exists()) # Run the task - should NOT purge since a policy is in use - check_and_purge_password_history() + check_and_purge_password_history.send() # Verify the entries still exist self.assertTrue(UserPasswordHistory.objects.exists()) @@ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase): enabled=True, order=0, ) - trim_password_histories.delay() + trim_password_histories.send() user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) self.assertEqual(len(user_pwd_history_qs), 1) @@ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase): enabled=False, order=0, ) - trim_password_histories.delay() + trim_password_histories.send() self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): @@ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase): enabled=True, order=0, ) - trim_password_histories.delay() + trim_password_histories.send() self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) diff --git a/authentik/enterprise/providers/google_workspace/api/providers.py b/authentik/enterprise/providers/google_workspace/api/providers.py index 772789ce05..ab162cefd9 100644 --- a/authentik/enterprise/providers/google_workspace/api/providers.py +++ b/authentik/enterprise/providers/google_workspace/api/providers.py @@ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi ] search_fields = ["name"] ordering = ["name"] - sync_single_task = google_workspace_sync + sync_task = google_workspace_sync sync_objects_task = google_workspace_sync_objects diff --git 
a/authentik/enterprise/providers/google_workspace/models.py b/authentik/enterprise/providers/google_workspace/models.py index 188f061fc6..e3fde7de74 100644 --- a/authentik/enterprise/providers/google_workspace/models.py +++ b/authentik/enterprise/providers/google_workspace/models.py @@ -7,6 +7,7 @@ from django.db import models from django.db.models import QuerySet from django.templatetags.static import static from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor from google.oauth2.service_account import Credentials from rest_framework.serializers import Serializer @@ -110,6 +111,12 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider): help_text=_("Property mappings used for group creation/updating."), ) + @property + def sync_actor(self) -> Actor: + from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync + + return google_workspace_sync + def client_for_model( self, model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], diff --git a/authentik/enterprise/providers/google_workspace/settings.py b/authentik/enterprise/providers/google_workspace/settings.py deleted file mode 100644 index 443a1a1884..0000000000 --- a/authentik/enterprise/providers/google_workspace/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Google workspace provider task Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "providers_google_workspace_sync": { - "task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all", - "schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/enterprise/providers/google_workspace/signals.py b/authentik/enterprise/providers/google_workspace/signals.py index 2e7eb70a94..ce8b5241b5 100644 --- 
a/authentik/enterprise/providers/google_workspace/signals.py +++ b/authentik/enterprise/providers/google_workspace/signals.py @@ -2,15 +2,13 @@ from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.tasks import ( - google_workspace_sync, - google_workspace_sync_direct, - google_workspace_sync_m2m, + google_workspace_sync_direct_dispatch, + google_workspace_sync_m2m_dispatch, ) from authentik.lib.sync.outgoing.signals import register_signals register_signals( GoogleWorkspaceProvider, - task_sync_single=google_workspace_sync, - task_sync_direct=google_workspace_sync_direct, - task_sync_m2m=google_workspace_sync_m2m, + task_sync_direct_dispatch=google_workspace_sync_direct_dispatch, + task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch, ) diff --git a/authentik/enterprise/providers/google_workspace/tasks.py b/authentik/enterprise/providers/google_workspace/tasks.py index 237076411a..00e9118513 100644 --- a/authentik/enterprise/providers/google_workspace/tasks.py +++ b/authentik/enterprise/providers/google_workspace/tasks.py @@ -1,37 +1,48 @@ """Google Provider tasks""" +from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import actor + from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider -from authentik.events.system_tasks import SystemTask -from authentik.lib.sync.outgoing.exceptions import TransientSyncException from authentik.lib.sync.outgoing.tasks import SyncTasks -from authentik.root.celery import CELERY_APP sync_tasks = SyncTasks(GoogleWorkspaceProvider) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Sync Google Workspace provider objects.")) def google_workspace_sync_objects(*args, **kwargs): return sync_tasks.sync_objects(*args, **kwargs) -@CELERY_APP.task( - base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True -) 
-def google_workspace_sync(self, provider_pk: int, *args, **kwargs): +@actor(description=_("Full sync for Google Workspace provider.")) +def google_workspace_sync(provider_pk: int, *args, **kwargs): """Run full sync for Google Workspace provider""" - return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects) + return sync_tasks.sync(provider_pk, google_workspace_sync_objects) -@CELERY_APP.task() -def google_workspace_sync_all(): - return sync_tasks.sync_all(google_workspace_sync) - - -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Sync a direct object (user, group) for Google Workspace provider.")) def google_workspace_sync_direct(*args, **kwargs): return sync_tasks.sync_signal_direct(*args, **kwargs) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor( + description=_( + "Dispatch syncs for a direct object (user, group) for Google Workspace providers." + ) +) +def google_workspace_sync_direct_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs) + + +@actor(description=_("Sync a related object (memberships) for Google Workspace provider.")) def google_workspace_sync_m2m(*args, **kwargs): return sync_tasks.sync_signal_m2m(*args, **kwargs) + + +@actor( + description=_( + "Dispatch syncs for a related object (memberships) for Google Workspace providers." 
+ ) +) +def google_workspace_sync_m2m_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs) diff --git a/authentik/enterprise/providers/google_workspace/tests/test_groups.py b/authentik/enterprise/providers/google_workspace/tests/test_groups.py index fbc52a94aa..c36509ddd2 100644 --- a/authentik/enterprise/providers/google_workspace/tests/test_groups.py +++ b/authentik/enterprise/providers/google_workspace/tests/test_groups.py @@ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase): "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", MagicMock(return_value={"developerKey": self.api_key, "http": http}), ): - google_workspace_sync.delay(self.provider.pk).get() + google_workspace_sync.send(self.provider.pk).get_result() self.assertTrue( GoogleWorkspaceProviderGroup.objects.filter( group=different_group, provider=self.provider diff --git a/authentik/enterprise/providers/google_workspace/tests/test_users.py b/authentik/enterprise/providers/google_workspace/tests/test_users.py index 0f7ec0d14e..8452976e67 100644 --- a/authentik/enterprise/providers/google_workspace/tests/test_users.py +++ b/authentik/enterprise/providers/google_workspace/tests/test_users.py @@ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase): "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", MagicMock(return_value={"developerKey": self.api_key, "http": http}), ): - google_workspace_sync.delay(self.provider.pk).get() + google_workspace_sync.send(self.provider.pk).get_result() self.assertTrue( GoogleWorkspaceProviderUser.objects.filter( user=different_user, provider=self.provider diff --git a/authentik/enterprise/providers/microsoft_entra/api/providers.py b/authentik/enterprise/providers/microsoft_entra/api/providers.py index f1a9396e9d..20ddfa2d1f 100644 --- a/authentik/enterprise/providers/microsoft_entra/api/providers.py +++ 
b/authentik/enterprise/providers/microsoft_entra/api/providers.py @@ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin ] search_fields = ["name"] ordering = ["name"] - sync_single_task = microsoft_entra_sync + sync_task = microsoft_entra_sync sync_objects_task = microsoft_entra_sync_objects diff --git a/authentik/enterprise/providers/microsoft_entra/models.py b/authentik/enterprise/providers/microsoft_entra/models.py index 56bb346bcc..37c6512d58 100644 --- a/authentik/enterprise/providers/microsoft_entra/models.py +++ b/authentik/enterprise/providers/microsoft_entra/models.py @@ -8,6 +8,7 @@ from django.db import models from django.db.models import QuerySet from django.templatetags.static import static from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor from rest_framework.serializers import Serializer from authentik.core.models import ( @@ -99,6 +100,12 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider): help_text=_("Property mappings used for group creation/updating."), ) + @property + def sync_actor(self) -> Actor: + from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync + + return microsoft_entra_sync + def client_for_model( self, model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], diff --git a/authentik/enterprise/providers/microsoft_entra/settings.py b/authentik/enterprise/providers/microsoft_entra/settings.py deleted file mode 100644 index 08ef592de8..0000000000 --- a/authentik/enterprise/providers/microsoft_entra/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Microsoft Entra provider task Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "providers_microsoft_entra_sync": { - "task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all", - "schedule": 
crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/enterprise/providers/microsoft_entra/signals.py b/authentik/enterprise/providers/microsoft_entra/signals.py index b9063ccb8b..75915d7117 100644 --- a/authentik/enterprise/providers/microsoft_entra/signals.py +++ b/authentik/enterprise/providers/microsoft_entra/signals.py @@ -2,15 +2,13 @@ from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.tasks import ( - microsoft_entra_sync, - microsoft_entra_sync_direct, - microsoft_entra_sync_m2m, + microsoft_entra_sync_direct_dispatch, + microsoft_entra_sync_m2m_dispatch, ) from authentik.lib.sync.outgoing.signals import register_signals register_signals( MicrosoftEntraProvider, - task_sync_single=microsoft_entra_sync, - task_sync_direct=microsoft_entra_sync_direct, - task_sync_m2m=microsoft_entra_sync_m2m, + task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch, + task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch, ) diff --git a/authentik/enterprise/providers/microsoft_entra/tasks.py b/authentik/enterprise/providers/microsoft_entra/tasks.py index 6985b8acfa..0c0f922f9f 100644 --- a/authentik/enterprise/providers/microsoft_entra/tasks.py +++ b/authentik/enterprise/providers/microsoft_entra/tasks.py @@ -1,37 +1,46 @@ """Microsoft Entra Provider tasks""" +from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import actor + from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider -from authentik.events.system_tasks import SystemTask -from authentik.lib.sync.outgoing.exceptions import TransientSyncException from authentik.lib.sync.outgoing.tasks import SyncTasks -from authentik.root.celery import CELERY_APP sync_tasks = SyncTasks(MicrosoftEntraProvider) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) 
+@actor(description=_("Sync Microsoft Entra provider objects.")) def microsoft_entra_sync_objects(*args, **kwargs): return sync_tasks.sync_objects(*args, **kwargs) -@CELERY_APP.task( - base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True -) -def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs): +@actor(description=_("Full sync for Microsoft Entra provider.")) +def microsoft_entra_sync(provider_pk: int, *args, **kwargs): """Run full sync for Microsoft Entra provider""" - return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects) + return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects) -@CELERY_APP.task() -def microsoft_entra_sync_all(): - return sync_tasks.sync_all(microsoft_entra_sync) - - -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider.")) def microsoft_entra_sync_direct(*args, **kwargs): return sync_tasks.sync_signal_direct(*args, **kwargs) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor( + description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.") +) +def microsoft_entra_sync_direct_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs) + + +@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider.")) def microsoft_entra_sync_m2m(*args, **kwargs): return sync_tasks.sync_signal_m2m(*args, **kwargs) + + +@actor( + description=_( + "Dispatch syncs for a related object (memberships) for Microsoft Entra providers." 
+ ) +) +def microsoft_entra_sync_m2m_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs) diff --git a/authentik/enterprise/providers/microsoft_entra/tests/test_groups.py b/authentik/enterprise/providers/microsoft_entra/tests/test_groups.py index c39d0ca206..d73f7149ea 100644 --- a/authentik/enterprise/providers/microsoft_entra/tests/test_groups.py +++ b/authentik/enterprise/providers/microsoft_entra/tests/test_groups.py @@ -252,9 +252,13 @@ class MicrosoftEntraGroupTests(TestCase): member_add.assert_called_once() self.assertEqual( member_add.call_args[0][0].odata_id, - f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( + f"https://graph.microsoft.com/v1.0/directoryObjects/{ + MicrosoftEntraProviderUser.objects.filter( provider=self.provider, - ).first().microsoft_id}", + ) + .first() + .microsoft_id + }", ) def test_group_create_member_remove(self): @@ -311,9 +315,13 @@ class MicrosoftEntraGroupTests(TestCase): member_add.assert_called_once() self.assertEqual( member_add.call_args[0][0].odata_id, - f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( + f"https://graph.microsoft.com/v1.0/directoryObjects/{ + MicrosoftEntraProviderUser.objects.filter( provider=self.provider, - ).first().microsoft_id}", + ) + .first() + .microsoft_id + }", ) member_remove.assert_called_once() @@ -413,7 +421,7 @@ class MicrosoftEntraGroupTests(TestCase): ), ) as group_list, ): - microsoft_entra_sync.delay(self.provider.pk).get() + microsoft_entra_sync.send(self.provider.pk).get_result() self.assertTrue( MicrosoftEntraProviderGroup.objects.filter( group=different_group, provider=self.provider diff --git a/authentik/enterprise/providers/microsoft_entra/tests/test_users.py b/authentik/enterprise/providers/microsoft_entra/tests/test_users.py index 8c46e998aa..403d0e0ef9 100644 --- 
a/authentik/enterprise/providers/microsoft_entra/tests/test_users.py +++ b/authentik/enterprise/providers/microsoft_entra/tests/test_users.py @@ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase): AsyncMock(return_value=GroupCollectionResponse(value=[])), ), ): - microsoft_entra_sync.delay(self.provider.pk).get() + microsoft_entra_sync.send(self.provider.pk).get_result() self.assertTrue( MicrosoftEntraProviderUser.objects.filter( user=different_user, provider=self.provider diff --git a/authentik/enterprise/providers/ssf/models.py b/authentik/enterprise/providers/ssf/models.py index 9e34031c58..78b3129285 100644 --- a/authentik/enterprise/providers/ssf/models.py +++ b/authentik/enterprise/providers/ssf/models.py @@ -17,6 +17,7 @@ from authentik.crypto.models import CertificateKeyPair from authentik.lib.models import CreatedUpdatedModel from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider +from authentik.tasks.models import TasksModel class EventTypes(models.TextChoices): @@ -42,7 +43,7 @@ class SSFEventStatus(models.TextChoices): SENT = "sent" -class SSFProvider(BackchannelProvider): +class SSFProvider(TasksModel, BackchannelProvider): """Shared Signals Framework provider to allow applications to receive user events from authentik.""" diff --git a/authentik/enterprise/providers/ssf/signals.py b/authentik/enterprise/providers/ssf/signals.py index 10b3c40e10..1d4e657790 100644 --- a/authentik/enterprise/providers/ssf/signals.py +++ b/authentik/enterprise/providers/ssf/signals.py @@ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import ( EventTypes, SSFProvider, ) -from authentik.enterprise.providers.ssf.tasks import send_ssf_event +from authentik.enterprise.providers.ssf.tasks import send_ssf_events from authentik.events.middleware import audit_ignore from authentik.stages.authenticator.models import Device from 
authentik.stages.authenticator_duo.models import DuoDevice @@ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi As this signal is also triggered with a regular logout, we can't be sure if the session has been deleted by an admin or by the user themselves.""" - send_ssf_event( + send_ssf_events( EventTypes.CAEP_SESSION_REVOKED, { "initiating_entity": "user", @@ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi @receiver(password_changed) def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): """Credential change trigger (password changed)""" - send_ssf_event( + send_ssf_events( EventTypes.CAEP_CREDENTIAL_CHANGE, { "credential_type": "password", @@ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, * } if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: data["fido2_aaguid"] = instance.aaguid - send_ssf_event( + send_ssf_events( EventTypes.CAEP_CREDENTIAL_CHANGE, data, sub_id={ @@ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_): } if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: data["fido2_aaguid"] = instance.aaguid - send_ssf_event( + send_ssf_events( EventTypes.CAEP_CREDENTIAL_CHANGE, data, sub_id={ diff --git a/authentik/enterprise/providers/ssf/tasks.py b/authentik/enterprise/providers/ssf/tasks.py index e3779d6d39..d842cc287a 100644 --- a/authentik/enterprise/providers/ssf/tasks.py +++ b/authentik/enterprise/providers/ssf/tasks.py @@ -1,7 +1,11 @@ -from celery import group +from typing import Any +from uuid import UUID + from django.http import HttpRequest from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from requests.exceptions import 
RequestException from structlog.stdlib import get_logger @@ -13,19 +17,16 @@ from authentik.enterprise.providers.ssf.models import ( Stream, StreamEvent, ) -from authentik.events.logs import LogEvent -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.lib.utils.http import get_http_session from authentik.lib.utils.time import timedelta_from_string from authentik.policies.engine import PolicyEngine -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task session = get_http_session() LOGGER = get_logger() -def send_ssf_event( +def send_ssf_events( event_type: EventTypes, data: dict, stream_filter: dict | None = None, @@ -33,7 +34,7 @@ def send_ssf_event( **extra_data, ): """Wrapper to send an SSF event to multiple streams""" - payload = [] + events_data = {} if not stream_filter: stream_filter = {} stream_filter["events_requested__contains"] = [event_type] @@ -41,16 +42,22 @@ def send_ssf_event( extra_data.setdefault("txn", request.request_id) for stream in Stream.objects.filter(**stream_filter): event_data = stream.prepare_event_payload(event_type, data, **extra_data) - payload.append((str(stream.uuid), event_data)) - return _send_ssf_event.delay(payload) + events_data[stream.uuid] = event_data + ssf_events_dispatch.send(events_data) -def _check_app_access(stream_uuid: str, event_data: dict) -> bool: +@actor(description=_("Dispatch SSF events.")) +def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]): + for stream_uuid, event_data in events_data.items(): + stream = Stream.objects.filter(pk=stream_uuid).first() + if not stream: + continue + send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider) + + +def _check_app_access(stream: Stream, event_data: dict) -> bool: """Check if event is related to user and if so, check if the user has access to the application""" - stream = Stream.objects.filter(pk=stream_uuid).first() - if not 
stream: - return False # `event_data` is a dict version of a StreamEvent sub_id = event_data.get("payload", {}).get("sub_id", {}) email = sub_id.get("user", {}).get("email", None) @@ -65,42 +72,22 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool: return engine.passing -@CELERY_APP.task() -def _send_ssf_event(event_data: list[tuple[str, dict]]): - tasks = [] - for stream, data in event_data: - if not _check_app_access(stream, data): - continue - event = StreamEvent.objects.create(**data) - tasks.extend(send_single_ssf_event(stream, str(event.uuid))) - main_task = group(*tasks) - main_task() +@actor(description=_("Send an SSF event.")) +def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): + self: Task = CurrentTask.get_task() - -def send_single_ssf_event(stream_id: str, evt_id: str): - stream = Stream.objects.filter(pk=stream_id).first() + stream = Stream.objects.filter(pk=stream_uuid).first() if not stream: return - event = StreamEvent.objects.filter(pk=evt_id).first() - if not event: + if not _check_app_access(stream, event_data): return + event = StreamEvent.objects.create(**event_data) + self.set_uid(event.pk) if event.status == SSFEventStatus.SENT: return - if stream.delivery_method == DeliveryMethods.RISC_PUSH: - return [ssf_push_event.si(str(event.pk))] - return [] - - -@CELERY_APP.task(bind=True, base=SystemTask) -def ssf_push_event(self: SystemTask, event_id: str): - self.save_on_success = False - event = StreamEvent.objects.filter(pk=event_id).first() - if not event: - return - self.set_uid(event_id) - if event.status == SSFEventStatus.SENT: - self.set_status(TaskStatus.SUCCESSFUL) + if stream.delivery_method != DeliveryMethods.RISC_PUSH: return + try: response = session.post( event.stream.endpoint_url, @@ -110,26 +97,17 @@ def ssf_push_event(self: SystemTask, event_id: str): response.raise_for_status() event.status = SSFEventStatus.SENT event.save() - self.set_status(TaskStatus.SUCCESSFUL) return except RequestException as 
exc: LOGGER.warning("Failed to send SSF event", exc=exc) - self.set_status(TaskStatus.ERROR) attrs = {} if exc.response: attrs["response"] = { "content": exc.response.text, "status": exc.response.status_code, } - self.set_error( - exc, - LogEvent( - _("Failed to send request"), - log_level="warning", - logger=self.__name__, - attributes=attrs, - ), - ) + self.warning(exc) + self.warning("Failed to send request", **attrs) # Re-up the expiry of the stream event event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) event.status = SSFEventStatus.PENDING_FAILED diff --git a/authentik/enterprise/providers/ssf/views/stream.py b/authentik/enterprise/providers/ssf/views/stream.py index 0b06ec0efb..2b36271739 100644 --- a/authentik/enterprise/providers/ssf/views/stream.py +++ b/authentik/enterprise/providers/ssf/views/stream.py @@ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import ( SSFProvider, Stream, ) -from authentik.enterprise.providers.ssf.tasks import send_ssf_event +from authentik.enterprise.providers.ssf.tasks import send_ssf_events from authentik.enterprise.providers.ssf.views.base import SSFView LOGGER = get_logger() @@ -109,7 +109,7 @@ class StreamView(SSFView): "User does not have permission to create stream for this provider." 
) instance: Stream = stream.save(provider=self.provider) - send_ssf_event( + send_ssf_events( EventTypes.SET_VERIFICATION, { "state": None, diff --git a/authentik/enterprise/settings.py b/authentik/enterprise/settings.py index 59b8a0e8ca..97b988d605 100644 --- a/authentik/enterprise/settings.py +++ b/authentik/enterprise/settings.py @@ -1,17 +1,5 @@ """Enterprise additional settings""" -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "enterprise_update_usage": { - "task": "authentik.enterprise.tasks.enterprise_update_usage", - "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), - "options": {"queue": "authentik_scheduled"}, - } -} - TENANT_APPS = [ "authentik.enterprise.audit", "authentik.enterprise.policies.unique_password", diff --git a/authentik/enterprise/signals.py b/authentik/enterprise/signals.py index 4fb11d4841..78eedf0b08 100644 --- a/authentik/enterprise/signals.py +++ b/authentik/enterprise/signals.py @@ -10,6 +10,7 @@ from django.utils.timezone import get_current_timezone from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE from authentik.enterprise.models import License from authentik.enterprise.tasks import enterprise_update_usage +from authentik.tasks.schedules.models import Schedule @receiver(pre_save, sender=License) @@ -26,7 +27,7 @@ def pre_save_license(sender: type[License], instance: License, **_): def post_save_license(sender: type[License], instance: License, **_): """Trigger license usage calculation when license is saved""" cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) - enterprise_update_usage.delay() + Schedule.dispatch_by_actor(enterprise_update_usage) @receiver(post_delete, sender=License) diff --git a/authentik/enterprise/tasks.py b/authentik/enterprise/tasks.py index a55ab5e13d..7c5a3bbea0 100644 --- a/authentik/enterprise/tasks.py +++ b/authentik/enterprise/tasks.py @@ -1,14 +1,11 @@ """Enterprise tasks""" +from 
django.utils.translation import gettext_lazy as _ +from dramatiq.actor import actor + from authentik.enterprise.license import LicenseKey -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask, prefill_task -from authentik.root.celery import CELERY_APP -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def enterprise_update_usage(self: SystemTask): - """Update enterprise license status""" +@actor(description=_("Update enterprise license status.")) +def enterprise_update_usage(): LicenseKey.get_total().record_usage() - self.set_status(TaskStatus.SUCCESSFUL) diff --git a/authentik/events/api/tasks.py b/authentik/events/api/tasks.py deleted file mode 100644 index 51c0b611c4..0000000000 --- a/authentik/events/api/tasks.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Tasks API""" - -from importlib import import_module - -from django.contrib import messages -from django.utils.translation import gettext_lazy as _ -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiResponse, extend_schema -from rest_framework.decorators import action -from rest_framework.fields import ( - CharField, - ChoiceField, - DateTimeField, - FloatField, - SerializerMethodField, -) -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.viewsets import ReadOnlyModelViewSet -from structlog.stdlib import get_logger - -from authentik.core.api.utils import ModelSerializer -from authentik.events.logs import LogEventSerializer -from authentik.events.models import SystemTask, TaskStatus -from authentik.rbac.decorators import permission_required - -LOGGER = get_logger() - - -class SystemTaskSerializer(ModelSerializer): - """Serialize TaskInfo and TaskResult""" - - name = CharField() - full_name = SerializerMethodField() - uid = CharField(required=False) - description = CharField() - start_timestamp = DateTimeField(read_only=True) - finish_timestamp = 
DateTimeField(read_only=True) - duration = FloatField(read_only=True) - - status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) - messages = LogEventSerializer(many=True) - - def get_full_name(self, instance: SystemTask) -> str: - """Get full name with UID""" - if instance.uid: - return f"{instance.name}:{instance.uid}" - return instance.name - - class Meta: - model = SystemTask - fields = [ - "uuid", - "name", - "full_name", - "uid", - "description", - "start_timestamp", - "finish_timestamp", - "duration", - "status", - "messages", - "expires", - "expiring", - ] - - -class SystemTaskViewSet(ReadOnlyModelViewSet): - """Read-only view set that returns all background tasks""" - - queryset = SystemTask.objects.all() - serializer_class = SystemTaskSerializer - filterset_fields = ["name", "uid", "status"] - ordering = ["name", "uid", "status"] - search_fields = ["name", "description", "uid", "status"] - - @permission_required(None, ["authentik_events.run_task"]) - @extend_schema( - request=OpenApiTypes.NONE, - responses={ - 204: OpenApiResponse(description="Task retried successfully"), - 404: OpenApiResponse(description="Task not found"), - 500: OpenApiResponse(description="Failed to retry task"), - }, - ) - @action(detail=True, methods=["POST"], permission_classes=[]) - def run(self, request: Request, pk=None) -> Response: - """Run task""" - task: SystemTask = self.get_object() - try: - task_module = import_module(task.task_call_module) - task_func = getattr(task_module, task.task_call_func) - LOGGER.info("Running task", task=task_func) - task_func.delay(*task.task_call_args, **task.task_call_kwargs) - messages.success( - self.request, - _("Successfully started task {name}.".format_map({"name": task.name})), - ) - return Response(status=204) - except (ImportError, AttributeError) as exc: # pragma: no cover - LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc) - # if we get an import error, the module path has probably changed - 
task.delete() - return Response(status=500) diff --git a/authentik/events/apps.py b/authentik/events/apps.py index 915b923afc..ccca6df76b 100644 --- a/authentik/events/apps.py +++ b/authentik/events/apps.py @@ -1,12 +1,11 @@ """authentik events app""" -from celery.schedules import crontab from prometheus_client import Gauge, Histogram from authentik.blueprints.apps import ManagedAppConfig from authentik.lib.config import CONFIG, ENV_PREFIX -from authentik.lib.utils.reflection import path_to_class -from authentik.root.celery import CELERY_APP +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec # TODO: Deprecated metric - remove in 2024.2 or later GAUGE_TASKS = Gauge( @@ -35,6 +34,17 @@ class AuthentikEventsConfig(ManagedAppConfig): verbose_name = "authentik Events" default = True + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.events.tasks import notification_cleanup + + return [ + ScheduleSpec( + actor=notification_cleanup, + crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *", + ), + ] + @ManagedAppConfig.reconcile_global def check_deprecations(self): """Check for config deprecations""" @@ -56,41 +66,3 @@ class AuthentikEventsConfig(ManagedAppConfig): replacement_env=replace_env, message=msg, ).save() - - @ManagedAppConfig.reconcile_tenant - def prefill_tasks(self): - """Prefill tasks""" - from authentik.events.models import SystemTask - from authentik.events.system_tasks import _prefill_tasks - - for task in _prefill_tasks: - if SystemTask.objects.filter(name=task.name).exists(): - continue - task.save() - self.logger.debug("prefilled task", task_name=task.name) - - @ManagedAppConfig.reconcile_tenant - def run_scheduled_tasks(self): - """Run schedule tasks which are behind schedule (only applies - to tasks of which we keep metrics)""" - from authentik.events.models import TaskStatus - from authentik.events.system_tasks import SystemTask as CelerySystemTask - - for 
task in CELERY_APP.conf["beat_schedule"].values(): - schedule = task["schedule"] - if not isinstance(schedule, crontab): - continue - task_class: CelerySystemTask = path_to_class(task["task"]) - if not isinstance(task_class, CelerySystemTask): - continue - db_task = task_class.db() - if not db_task: - continue - due, _ = schedule.is_due(db_task.finish_timestamp) - if due or db_task.status == TaskStatus.UNKNOWN: - self.logger.debug("Running past-due scheduled task", task=task["task"]) - task_class.apply_async( - args=task.get("args", None), - kwargs=task.get("kwargs", None), - **task.get("options", {}), - ) diff --git a/authentik/events/migrations/0011_alter_systemtask_options.py b/authentik/events/migrations/0011_alter_systemtask_options.py new file mode 100644 index 0000000000..1784ef7ea5 --- /dev/null +++ b/authentik/events/migrations/0011_alter_systemtask_options.py @@ -0,0 +1,22 @@ +# Generated by Django 5.1.11 on 2025-06-24 15:36 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"), + ] + + operations = [ + migrations.AlterModelOptions( + name="systemtask", + options={ + "default_permissions": (), + "permissions": (), + "verbose_name": "System Task", + "verbose_name_plural": "System Tasks", + }, + ), + ] diff --git a/authentik/events/models.py b/authentik/events/models.py index b2b00a048e..8b593b2262 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -5,12 +5,11 @@ from datetime import timedelta from difflib import get_close_matches from functools import lru_cache from inspect import currentframe -from smtplib import SMTPException from typing import Any from uuid import uuid4 from django.apps import apps -from django.db import connection, models +from django.db import models from django.http import HttpRequest from django.http.request import QueryDict from django.utils.timezone import now @@ -27,7 
+26,6 @@ from authentik.core.middleware import ( SESSION_KEY_IMPERSONATE_USER, ) from authentik.core.models import ExpiringModel, Group, PropertyMapping, User -from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME from authentik.events.context_processors.base import get_context_processors from authentik.events.utils import ( cleanse_dict, @@ -44,6 +42,7 @@ from authentik.lib.utils.time import timedelta_from_string from authentik.policies.models import PolicyBindingModel from authentik.root.middleware import ClientIPMiddleware from authentik.stages.email.utils import TemplateEmailMessage +from authentik.tasks.models import TasksModel from authentik.tenants.models import Tenant from authentik.tenants.utils import get_current_tenant @@ -274,7 +273,8 @@ class Event(SerializerModel, ExpiringModel): models.Index(fields=["created"]), models.Index(fields=["client_ip"]), models.Index( - models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" + models.F("context__authorized_application"), + name="authentik_e_ctx_app__idx", ), ] @@ -288,7 +288,7 @@ class TransportMode(models.TextChoices): EMAIL = "email", _("Email") -class NotificationTransport(SerializerModel): +class NotificationTransport(TasksModel, SerializerModel): """Action which is executed when a Rule matches""" uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) @@ -453,6 +453,8 @@ class NotificationTransport(SerializerModel): def send_email(self, notification: "Notification") -> list[str]: """Send notification via global email configuration""" + from authentik.stages.email.tasks import send_mail + if notification.user.email.strip() == "": LOGGER.info( "Discarding notification as user has no email address", @@ -494,17 +496,14 @@ class NotificationTransport(SerializerModel): template_name="email/event_notification.html", template_context=context, ) - # Email is sent directly here, as the call to send() should have been from a task. 
- try: - from authentik.stages.email.tasks import send_mail - - return send_mail(mail.__dict__) - except (SMTPException, ConnectionError, OSError) as exc: - raise NotificationTransportError(exc) from exc + send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self) + return [] @property def serializer(self) -> type[Serializer]: - from authentik.events.api.notification_transports import NotificationTransportSerializer + from authentik.events.api.notification_transports import ( + NotificationTransportSerializer, + ) return NotificationTransportSerializer @@ -554,7 +553,7 @@ class Notification(SerializerModel): verbose_name_plural = _("Notifications") -class NotificationRule(SerializerModel, PolicyBindingModel): +class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel): """Decide when to create a Notification based on policies attached to this object.""" name = models.TextField(unique=True) @@ -618,7 +617,9 @@ class NotificationWebhookMapping(PropertyMapping): @property def serializer(self) -> type[type[Serializer]]: - from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer + from authentik.events.api.notification_mappings import ( + NotificationWebhookMappingSerializer, + ) return NotificationWebhookMappingSerializer @@ -631,7 +632,7 @@ class NotificationWebhookMapping(PropertyMapping): class TaskStatus(models.TextChoices): - """Possible states of tasks""" + """DEPRECATED do not use""" UNKNOWN = "unknown" SUCCESSFUL = "successful" @@ -639,8 +640,8 @@ class TaskStatus(models.TextChoices): ERROR = "error" -class SystemTask(SerializerModel, ExpiringModel): - """Info about a system task running in the background along with details to restart the task""" +class SystemTask(ExpiringModel): + """DEPRECATED do not use""" uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) name = models.TextField() @@ -660,41 +661,13 @@ class SystemTask(SerializerModel, ExpiringModel): task_call_args = 
models.JSONField(default=list) task_call_kwargs = models.JSONField(default=dict) - @property - def serializer(self) -> type[Serializer]: - from authentik.events.api.tasks import SystemTaskSerializer - - return SystemTaskSerializer - - def update_metrics(self): - """Update prometheus metrics""" - # TODO: Deprecated metric - remove in 2024.2 or later - GAUGE_TASKS.labels( - tenant=connection.schema_name, - task_name=self.name, - task_uid=self.uid or "", - status=self.status.lower(), - ).set(self.duration) - SYSTEM_TASK_TIME.labels( - tenant=connection.schema_name, - task_name=self.name, - task_uid=self.uid or "", - ).observe(self.duration) - SYSTEM_TASK_STATUS.labels( - tenant=connection.schema_name, - task_name=self.name, - task_uid=self.uid or "", - status=self.status.lower(), - ).inc() - def __str__(self) -> str: return f"System Task {self.name}" class Meta: unique_together = (("name", "uid"),) - # Remove "add", "change" and "delete" permissions as those are not used - default_permissions = ["view"] - permissions = [("run_task", _("Run task"))] + default_permissions = () + permissions = () verbose_name = _("System Task") verbose_name_plural = _("System Tasks") indexes = ExpiringModel.Meta.indexes diff --git a/authentik/events/settings.py b/authentik/events/settings.py deleted file mode 100644 index 5fc978ef27..0000000000 --- a/authentik/events/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Event Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "events_notification_cleanup": { - "task": "authentik.events.tasks.notification_cleanup", - "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/events/signals.py b/authentik/events/signals.py index 232ef605a8..0ae2ca7857 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -12,13 +12,10 @@ from rest_framework.request 
import Request from authentik.core.models import AuthenticatedSession, User from authentik.core.signals import login_failed, password_changed -from authentik.events.apps import SYSTEM_TASK_STATUS -from authentik.events.models import Event, EventAction, SystemTask -from authentik.events.tasks import event_notification_handler, gdpr_cleanup +from authentik.events.models import Event, EventAction from authentik.flows.models import Stage from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan from authentik.flows.views.executor import SESSION_KEY_PLAN -from authentik.root.monitoring import monitoring_set from authentik.stages.invitation.models import Invitation from authentik.stages.invitation.signals import invitation_used from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS @@ -114,19 +111,15 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest @receiver(post_save, sender=Event) def event_post_save_notification(sender, instance: Event, **_): """Start task to check if any policies trigger an notification on this event""" - event_notification_handler.delay(instance.event_uuid.hex) + from authentik.events.tasks import event_trigger_dispatch + + event_trigger_dispatch.send(instance.event_uuid) @receiver(pre_delete, sender=User) def event_user_pre_delete_cleanup(sender, instance: User, **_): """If gdpr_compliance is enabled, remove all the user's events""" + from authentik.events.tasks import gdpr_cleanup + if get_current_tenant().gdpr_compliance: - gdpr_cleanup.delay(instance.pk) - - -@receiver(monitoring_set) -def monitoring_system_task(sender, **_): - """Update metrics when task is saved""" - SYSTEM_TASK_STATUS.clear() - for task in SystemTask.objects.all(): - task.update_metrics() + gdpr_cleanup.send(instance.pk) diff --git a/authentik/events/system_tasks.py b/authentik/events/system_tasks.py deleted file mode 100644 index 0438b2887e..0000000000 --- 
a/authentik/events/system_tasks.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Monitored tasks""" - -from datetime import datetime, timedelta -from time import perf_counter -from typing import Any - -from django.utils.timezone import now -from django.utils.translation import gettext_lazy as _ -from structlog.stdlib import BoundLogger, get_logger -from tenant_schemas_celery.task import TenantTask - -from authentik.events.logs import LogEvent -from authentik.events.models import Event, EventAction, TaskStatus -from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.utils import sanitize_item -from authentik.lib.utils.errors import exception_to_string - - -class SystemTask(TenantTask): - """Task which can save its state to the cache""" - - logger: BoundLogger - - # For tasks that should only be listed if they failed, set this to False - save_on_success: bool - - _status: TaskStatus - _messages: list[LogEvent] - - _uid: str | None - # Precise start time from perf_counter - _start_precise: float | None = None - _start: datetime | None = None - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._status = TaskStatus.SUCCESSFUL - self.save_on_success = True - self._uid = None - self._status = None - self._messages = [] - self.result_timeout_hours = 6 - - def set_uid(self, uid: str): - """Set UID, so in the case of an unexpected error its saved correctly""" - self._uid = uid - - def set_status(self, status: TaskStatus, *messages: LogEvent): - """Set result for current run, will overwrite previous result.""" - self._status = status - self._messages = list(messages) - for idx, msg in enumerate(self._messages): - if not isinstance(msg, LogEvent): - self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") - - def set_error(self, exception: Exception, *messages: LogEvent): - """Set result to error and save exception""" - self._status = TaskStatus.ERROR - self._messages = list(messages) - 
self._messages.extend( - [LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")] - ) - - def before_start(self, task_id, args, kwargs): - self._start_precise = perf_counter() - self._start = now() - self.logger = get_logger().bind(task_id=task_id) - return super().before_start(task_id, args, kwargs) - - def db(self) -> DBSystemTask | None: - """Get DB object for latest task""" - return DBSystemTask.objects.filter( - name=self.__name__, - uid=self._uid, - ).first() - - def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): - super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) - if not self._status: - return - if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success: - DBSystemTask.objects.filter( - name=self.__name__, - uid=self._uid, - ).delete() - return - DBSystemTask.objects.update_or_create( - name=self.__name__, - uid=self._uid, - defaults={ - "description": self.__doc__, - "start_timestamp": self._start or now(), - "finish_timestamp": now(), - "duration": max(perf_counter() - self._start_precise, 0), - "task_call_module": self.__module__, - "task_call_func": self.__name__, - "task_call_args": sanitize_item(args), - "task_call_kwargs": sanitize_item(kwargs), - "status": self._status, - "messages": sanitize_item(self._messages), - "expires": now() + timedelta(hours=self.result_timeout_hours), - "expiring": True, - }, - ) - - def on_failure(self, exc, task_id, args, kwargs, einfo): - super().on_failure(exc, task_id, args, kwargs, einfo=einfo) - if not self._status: - self.set_error(exc) - DBSystemTask.objects.update_or_create( - name=self.__name__, - uid=self._uid, - defaults={ - "description": self.__doc__, - "start_timestamp": self._start or now(), - "finish_timestamp": now(), - "duration": max(perf_counter() - self._start_precise, 0), - "task_call_module": self.__module__, - "task_call_func": self.__name__, - "task_call_args": sanitize_item(args), - 
"task_call_kwargs": sanitize_item(kwargs), - "status": self._status, - "messages": sanitize_item(self._messages), - "expires": now() + timedelta(hours=self.result_timeout_hours + 3), - "expiring": True, - }, - ) - Event.new( - EventAction.SYSTEM_TASK_EXCEPTION, - message=f"Task {self.__name__} encountered an error", - ).with_exception(exc).save() - - def run(self, *args, **kwargs): - raise NotImplementedError - - -def prefill_task(func): - """Ensure a task's details are always in cache, so it can always be triggered via API""" - _prefill_tasks.append( - DBSystemTask( - name=func.__name__, - description=func.__doc__, - start_timestamp=now(), - finish_timestamp=now(), - status=TaskStatus.UNKNOWN, - messages=sanitize_item([_("Task has not been run yet.")]), - task_call_module=func.__module__, - task_call_func=func.__name__, - expiring=False, - duration=0, - ) - ) - return func - - -_prefill_tasks = [] diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index a53fbbaebf..162a5d5280 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -1,41 +1,49 @@ """Event notification tasks""" +from uuid import UUID + from django.db.models.query_utils import Q +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from guardian.shortcuts import get_anonymous_user from structlog.stdlib import get_logger -from authentik.core.expression.exceptions import PropertyMappingExpressionException from authentik.core.models import User from authentik.events.models import ( Event, Notification, NotificationRule, NotificationTransport, - NotificationTransportError, - TaskStatus, ) -from authentik.events.system_tasks import SystemTask, prefill_task from authentik.policies.engine import PolicyEngine from authentik.policies.models import PolicyBinding, PolicyEngineMode -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = 
get_logger() -@CELERY_APP.task() -def event_notification_handler(event_uuid: str): - """Start task for each trigger definition""" +@actor(description=_("Dispatch new event notifications.")) +def event_trigger_dispatch(event_uuid: UUID): for trigger in NotificationRule.objects.all(): - event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") + event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger) -@CELERY_APP.task() -def event_trigger_handler(event_uuid: str, trigger_name: str): +@actor( + description=_( + "Check if policies attached to NotificationRule match event " + "and dispatch notification tasks." + ) +) +def event_trigger_handler(event_uuid: UUID, trigger_name: str): """Check if policies attached to NotificationRule match event""" + self: Task = CurrentTask.get_task() + event: Event = Event.objects.filter(event_uuid=event_uuid).first() if not event: - LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) + self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) return + trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() if not trigger: return @@ -71,57 +79,46 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) # Create the notification objects + count = 0 for transport in trigger.transports.all(): for user in trigger.destination_users(event): - LOGGER.debug("created notification") - notification_transport.apply_async( - args=[ + notification_transport.send_with_options( + args=( transport.pk, - str(event.pk), + event.pk, user.pk, - str(trigger.pk), - ], - queue="authentik_events", + trigger.pk, + ), + rel_obj=transport, ) + count += 1 if transport.send_once: break + self.info(f"Created {count} notification tasks") -@CELERY_APP.task( - bind=True, - autoretry_for=(NotificationTransportError,), - retry_backoff=True, - 
base=SystemTask, -) -def notification_transport( - self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str -): +@actor(description=_("Send notification.")) +def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): """Send notification over specified transport""" - self.save_on_success = False - try: - event = Event.objects.filter(pk=event_pk).first() - if not event: - return - user = User.objects.filter(pk=user_pk).first() - if not user: - return - trigger = NotificationRule.objects.filter(pk=trigger_pk).first() - if not trigger: - return - notification = Notification( - severity=trigger.severity, body=event.summary, event=event, user=user - ) - transport = NotificationTransport.objects.filter(pk=transport_pk).first() - if not transport: - return - transport.send(notification) - self.set_status(TaskStatus.SUCCESSFUL) - except (NotificationTransportError, PropertyMappingExpressionException) as exc: - self.set_error(exc) - raise exc + event = Event.objects.filter(pk=event_pk).first() + if not event: + return + user = User.objects.filter(pk=user_pk).first() + if not user: + return + trigger = NotificationRule.objects.filter(pk=trigger_pk).first() + if not trigger: + return + notification = Notification( + severity=trigger.severity, body=event.summary, event=event, user=user + ) + transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first() + if not transport: + return + transport.send(notification) -@CELERY_APP.task() +@actor(description=_("Cleanup events for GDPR compliance.")) def gdpr_cleanup(user_pk: int): """cleanup events from gdpr_compliance""" events = Event.objects.filter(user__pk=user_pk) @@ -129,12 +126,12 @@ def gdpr_cleanup(user_pk: int): events.delete() -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def notification_cleanup(self: SystemTask): +@actor(description=_("Cleanup seen notifications and notifications whose event expired.")) +def 
notification_cleanup(): """Cleanup seen notifications and notifications whose event expired.""" + self: Task = CurrentTask.get_task() notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) amount = notifications.count() notifications.delete() LOGGER.debug("Expired notifications", amount=amount) - self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications") + self.info(f"Expired {amount} Notifications") diff --git a/authentik/events/tests/test_tasks.py b/authentik/events/tests/test_tasks.py deleted file mode 100644 index 353c11b07f..0000000000 --- a/authentik/events/tests/test_tasks.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Test Monitored tasks""" - -from json import loads - -from django.urls import reverse -from rest_framework.test import APITestCase - -from authentik.core.tasks import clean_expired_models -from authentik.core.tests.utils import create_test_admin_user -from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask -from authentik.lib.generators import generate_id -from authentik.root.celery import CELERY_APP - - -class TestSystemTasks(APITestCase): - """Test Monitored tasks""" - - def setUp(self): - super().setUp() - self.user = create_test_admin_user() - self.client.force_login(self.user) - - def test_failed_successful_remove_state(self): - """Test that a task with `save_on_success` set to `False` that failed saves - a state, and upon successful completion will delete the state""" - should_fail = True - uid = generate_id() - - @CELERY_APP.task( - bind=True, - base=SystemTask, - ) - def test_task(self: SystemTask): - self.save_on_success = False - self.set_uid(uid) - self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL) - - # First test successful run - should_fail = False - test_task.delay().get() - self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) - - # Then test 
failed - should_fail = True - test_task.delay().get() - task = DBSystemTask.objects.filter(name="test_task", uid=uid).first() - self.assertEqual(task.status, TaskStatus.ERROR) - - # Then after that, the state should be removed - should_fail = False - test_task.delay().get() - self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) - - def test_tasks(self): - """Test Task API""" - clean_expired_models.delay().get() - response = self.client.get(reverse("authentik_api:systemtask-list")) - self.assertEqual(response.status_code, 200) - body = loads(response.content) - self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"])) - - def test_tasks_single(self): - """Test Task API (read single)""" - clean_expired_models.delay().get() - task = DBSystemTask.objects.filter(name="clean_expired_models").first() - response = self.client.get( - reverse( - "authentik_api:systemtask-detail", - kwargs={"pk": str(task.pk)}, - ) - ) - self.assertEqual(response.status_code, 200) - body = loads(response.content) - self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value) - self.assertEqual(body["name"], "clean_expired_models") - response = self.client.get( - reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"}) - ) - self.assertEqual(response.status_code, 404) - - def test_tasks_run(self): - """Test Task API (run)""" - clean_expired_models.delay().get() - task = DBSystemTask.objects.filter(name="clean_expired_models").first() - response = self.client.post( - reverse( - "authentik_api:systemtask-run", - kwargs={"pk": str(task.pk)}, - ) - ) - self.assertEqual(response.status_code, 204) - - def test_tasks_run_404(self): - """Test Task API (run, 404)""" - response = self.client.post( - reverse( - "authentik_api:systemtask-run", - kwargs={"pk": "qwerqewrqrqewrqewr"}, - ) - ) - self.assertEqual(response.status_code, 404) diff --git a/authentik/events/urls.py b/authentik/events/urls.py index 0beea59913..b961b7dbb3 
100644 --- a/authentik/events/urls.py +++ b/authentik/events/urls.py @@ -5,13 +5,11 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin from authentik.events.api.notification_rules import NotificationRuleViewSet from authentik.events.api.notification_transports import NotificationTransportViewSet from authentik.events.api.notifications import NotificationViewSet -from authentik.events.api.tasks import SystemTaskViewSet api_urlpatterns = [ ("events/events", EventViewSet), ("events/notifications", NotificationViewSet), ("events/transports", NotificationTransportViewSet), ("events/rules", NotificationRuleViewSet), - ("events/system_tasks", SystemTaskViewSet), ("propertymappings/notification", NotificationWebhookMappingViewSet), ] diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 5dae033c6a..6146db2c2a 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -41,8 +41,7 @@ REDIS_ENV_KEYS = [ # Old key -> new key DEPRECATIONS = { "geoip": "events.context_processors.geoip", - "redis.broker_url": "broker.url", - "redis.broker_transport_options": "broker.transport_options", + "worker.concurrency": "worker.threads", "redis.cache_timeout": "cache.timeout", "redis.cache_timeout_flows": "cache.timeout_flows", "redis.cache_timeout_policies": "cache.timeout_policies", diff --git a/authentik/lib/debug.py b/authentik/lib/debug.py index 76d7422b6a..827e73b6a1 100644 --- a/authentik/lib/debug.py +++ b/authentik/lib/debug.py @@ -21,6 +21,10 @@ def start_debug_server(**kwargs) -> bool: listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") host, _, port = listen.rpartition(":") - debugpy.listen((host, int(port)), **kwargs) # nosec + try: + debugpy.listen((host, int(port)), **kwargs) # nosec + except RuntimeError: + LOGGER.warning("Could not start debug server. 
Continuing without") + return False LOGGER.debug("Starting debug server", host=host, port=port) return True diff --git a/authentik/lib/default.yml b/authentik/lib/default.yml index f587da2fd8..31389d8dea 100644 --- a/authentik/lib/default.yml +++ b/authentik/lib/default.yml @@ -57,10 +57,6 @@ redis: tls_reqs: "none" tls_ca_cert: null -# broker: -# url: "" -# transport_options: "" - http_timeout: 30 cache: @@ -72,10 +68,6 @@ cache: # channel: # url: "" -# result_backend: -# url: "" -# transport_options: "" - debug: false debugger: false @@ -157,7 +149,14 @@ web: path: / worker: - concurrency: 2 + processes: 1 + threads: 2 + consumer_listen_timeout: "seconds=30" + task_max_retries: 20 + task_default_time_limit: "minutes=10" + task_purge_interval: "days=1" + task_expiration: "days=30" + scheduler_interval: "seconds=60" storage: media: diff --git a/authentik/lib/logging.py b/authentik/lib/logging.py index 0ffb74530e..1b0b8aba7c 100644 --- a/authentik/lib/logging.py +++ b/authentik/lib/logging.py @@ -88,7 +88,6 @@ def get_logger_config(): "authentik": global_level, "django": "WARNING", "django.request": "ERROR", - "celery": "WARNING", "selenium": "WARNING", "docker": "WARNING", "urllib3": "WARNING", diff --git a/authentik/lib/sentry.py b/authentik/lib/sentry.py index 23aba5dc6f..0a59ef1537 100644 --- a/authentik/lib/sentry.py +++ b/authentik/lib/sentry.py @@ -3,8 +3,6 @@ from asyncio.exceptions import CancelledError from typing import Any -from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError -from celery.exceptions import CeleryError from channels_redis.core import ChannelFull from django.conf import settings from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError @@ -22,7 +20,6 @@ from sentry_sdk import HttpTransport, get_current_scope from sentry_sdk import init as sentry_sdk_init from sentry_sdk.api import set_tag from sentry_sdk.integrations.argv import ArgvIntegration -from sentry_sdk.integrations.celery 
import CeleryIntegration from sentry_sdk.integrations.django import DjangoIntegration from sentry_sdk.integrations.redis import RedisIntegration from sentry_sdk.integrations.socket import SocketIntegration @@ -71,10 +68,6 @@ ignored_classes = ( LocalProtocolError, # rest_framework error APIException, - # celery errors - WorkerLostError, - CeleryError, - SoftTimeLimitExceeded, # custom baseclass SentryIgnoredException, # ldap errors @@ -115,7 +108,6 @@ def sentry_init(**sentry_init_kwargs): ArgvIntegration(), StdlibIntegration(), DjangoIntegration(transaction_style="function_name", cache_spans=True), - CeleryIntegration(), RedisIntegration(), ThreadingIntegration(propagate_hub=True), SocketIntegration(), @@ -160,14 +152,11 @@ def before_send(event: dict, hint: dict) -> dict | None: return None if "logger" in event: if event["logger"] in [ - "kombu", "asyncio", "multiprocessing", "django_redis", "django.security.DisallowedHost", "django_redis.cache", - "celery.backends.redis", - "celery.worker", "paramiko.transport", ]: return None diff --git a/authentik/lib/sync/api.py b/authentik/lib/sync/api.py new file mode 100644 index 0000000000..b4343bc64d --- /dev/null +++ b/authentik/lib/sync/api.py @@ -0,0 +1,12 @@ +from rest_framework.fields import BooleanField, ChoiceField, DateTimeField + +from authentik.core.api.utils import PassiveSerializer +from authentik.tasks.models import TaskStatus + + +class SyncStatusSerializer(PassiveSerializer): + """Provider/source sync status""" + + is_running = BooleanField() + last_successful_sync = DateTimeField(required=False) + last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices) diff --git a/authentik/lib/sync/outgoing/__init__.py b/authentik/lib/sync/outgoing/__init__.py index 39d28cfc22..76b6d9232a 100644 --- a/authentik/lib/sync/outgoing/__init__.py +++ b/authentik/lib/sync/outgoing/__init__.py @@ -1,7 +1,7 @@ """Sync constants""" PAGE_SIZE = 100 -PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour +PAGE_TIMEOUT_MS = 
60 * 60 * 0.5 * 1000 # Half an hour HTTP_CONFLICT = 409 HTTP_NO_CONTENT = 204 HTTP_SERVICE_UNAVAILABLE = 503 diff --git a/authentik/lib/sync/outgoing/api.py b/authentik/lib/sync/outgoing/api.py index ee6a3c8e03..97d7882f70 100644 --- a/authentik/lib/sync/outgoing/api.py +++ b/authentik/lib/sync/outgoing/api.py @@ -1,7 +1,5 @@ -from celery import Task -from django.utils.text import slugify -from drf_spectacular.utils import OpenApiResponse, extend_schema -from guardian.shortcuts import get_objects_for_user +from dramatiq.actor import Actor +from drf_spectacular.utils import extend_schema from rest_framework.decorators import action from rest_framework.fields import BooleanField, CharField, ChoiceField from rest_framework.request import Request @@ -9,18 +7,12 @@ from rest_framework.response import Response from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.models import Group, User -from authentik.events.api.tasks import SystemTaskSerializer -from authentik.events.logs import LogEvent, LogEventSerializer +from authentik.events.logs import LogEventSerializer +from authentik.lib.sync.api import SyncStatusSerializer from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.utils.reflection import class_to_path from authentik.rbac.filters import ObjectFilter - - -class SyncStatusSerializer(PassiveSerializer): - """Provider sync status""" - - is_running = BooleanField(read_only=True) - tasks = SystemTaskSerializer(many=True, read_only=True) +from authentik.tasks.models import Task, TaskStatus class SyncObjectSerializer(PassiveSerializer): @@ -45,15 +37,10 @@ class SyncObjectResultSerializer(PassiveSerializer): class OutgoingSyncProviderStatusMixin: """Common API Endpoints for Outgoing sync providers""" - sync_single_task: type[Task] = None - sync_objects_task: type[Task] = None + sync_task: Actor + sync_objects_task: Actor - @extend_schema( - responses={ - 200: SyncStatusSerializer(), - 404: 
OpenApiResponse(description="Task not found"), - } - ) + @extend_schema(responses={200: SyncStatusSerializer()}) @action( methods=["GET"], detail=True, @@ -64,18 +51,39 @@ class OutgoingSyncProviderStatusMixin: def sync_status(self, request: Request, pk: int) -> Response: """Get provider's sync status""" provider: OutgoingSyncProvider = self.get_object() - tasks = list( - get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( - name=self.sync_single_task.__name__, - uid=slugify(provider.name), - ) - ) + + status = {} + with provider.sync_lock as lock_acquired: - status = { - "tasks": tasks, - # If we could not acquire the lock, it means a task is using it, and thus is running - "is_running": not lock_acquired, - } + # If we could not acquire the lock, it means a task is using it, and thus is running + status["is_running"] = not lock_acquired + + sync_schedule = None + for schedule in provider.schedules.all(): + if schedule.actor_name == self.sync_task.actor_name: + sync_schedule = schedule + + if not sync_schedule: + return Response(SyncStatusSerializer(status).data) + + last_task: Task = ( + sync_schedule.tasks.exclude( + aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) + ) + .order_by("-mtime") + .first() + ) + last_successful_task: Task = ( + sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) + .order_by("-mtime") + .first() + ) + + if last_task: + status["last_sync_status"] = last_task.aggregated_status + if last_successful_task: + status["last_successful_sync"] = last_successful_task.mtime + return Response(SyncStatusSerializer(status).data) @extend_schema( @@ -94,14 +102,20 @@ class OutgoingSyncProviderStatusMixin: provider: OutgoingSyncProvider = self.get_object() params = SyncObjectSerializer(data=request.data) params.is_valid(raise_exception=True) - res: list[LogEvent] = self.sync_objects_task.delay( - params.validated_data["sync_object_model"], - page=1, - provider_pk=provider.pk, - 
pk=params.validated_data["sync_object_id"], - override_dry_run=params.validated_data["override_dry_run"], - ).get() - return Response(SyncObjectResultSerializer(instance={"messages": res}).data) + msg = self.sync_objects_task.send_with_options( + kwargs={ + "object_type": params.validated_data["sync_object_model"], + "page": 1, + "provider_pk": provider.pk, + "override_dry_run": params.validated_data["override_dry_run"], + "pk": params.validated_data["sync_object_id"], + }, + rel_obj=provider, + ) + msg.get_result(block=True) + task: Task = msg.options["task"] + task.refresh_from_db() + return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data) class OutgoingSyncConnectionCreateMixin: diff --git a/authentik/lib/sync/outgoing/models.py b/authentik/lib/sync/outgoing/models.py index f1dad9de7b..02a5e71450 100644 --- a/authentik/lib/sync/outgoing/models.py +++ b/authentik/lib/sync/outgoing/models.py @@ -1,12 +1,18 @@ from typing import Any, Self import pglock +from django.core.paginator import Paginator from django.db import connection, models from django.db.models import Model, QuerySet, TextChoices from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor from authentik.core.models import Group, User +from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel class OutgoingSyncDeleteAction(TextChoices): @@ -18,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices): SUSPEND = "suspend" -class OutgoingSyncProvider(Model): +class OutgoingSyncProvider(ScheduledModel, Model): """Base abstract models for providers implementing outgoing sync""" dry_run = models.BooleanField( @@ -39,6 +45,19 @@ class OutgoingSyncProvider(Model): def get_object_qs[T: User | Group](self, type: 
type[T]) -> QuerySet[T]: raise NotImplementedError + def get_paginator[T: User | Group](self, type: type[T]) -> Paginator: + return Paginator(self.get_object_qs(type), PAGE_SIZE) + + def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int: + num_pages: int = self.get_paginator(type).num_pages + return int(num_pages * PAGE_TIMEOUT_MS * 1.5) + + def get_sync_time_limit_ms(self) -> int: + return int( + (self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group)) + * 1.5 + ) + @property def sync_lock(self) -> pglock.advisory: """Postgres lock for syncing to prevent multiple parallel syncs happening""" @@ -47,3 +66,22 @@ class OutgoingSyncProvider(Model): timeout=0, side_effect=pglock.Return, ) + + @property + def sync_actor(self) -> Actor: + raise NotImplementedError + + @property + def schedule_specs(self) -> list[ScheduleSpec]: + return [ + ScheduleSpec( + actor=self.sync_actor, + uid=self.name, + args=(self.pk,), + options={ + "time_limit": self.get_sync_time_limit_ms(), + }, + send_on_save=True, + crontab=f"{fqdn_rand(self.pk)} */4 * * *", + ), + ] diff --git a/authentik/lib/sync/outgoing/signals.py b/authentik/lib/sync/outgoing/signals.py index 0fd8c63ff2..af7edb2d73 100644 --- a/authentik/lib/sync/outgoing/signals.py +++ b/authentik/lib/sync/outgoing/signals.py @@ -1,12 +1,8 @@ -from collections.abc import Callable - -from django.core.paginator import Paginator from django.db.models import Model -from django.db.models.query import Q from django.db.models.signals import m2m_changed, post_save, pre_delete +from dramatiq.actor import Actor from authentik.core.models import Group, User -from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.utils.reflection import class_to_path @@ -14,45 +10,30 @@ from authentik.lib.utils.reflection import class_to_path def 
register_signals( provider_type: type[OutgoingSyncProvider], - task_sync_single: Callable[[int], None], - task_sync_direct: Callable[[int], None], - task_sync_m2m: Callable[[int], None], + task_sync_direct_dispatch: Actor[[str, str | int, str], None], + task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None], ): """Register sync signals""" uid = class_to_path(provider_type) - def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_): - """Trigger sync when Provider is saved""" - users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE) - groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE) - soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT - time_limit = soft_time_limit * 1.5 - task_sync_single.apply_async( - (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) - ) - - post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False) - def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): """Post save handler""" - if not provider_type.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value) + task_sync_direct_dispatch.send( + class_to_path(instance.__class__), + instance.pk, + Direction.add.value, + ) post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) def model_pre_delete(sender: type[Model], instance: User | Group, **_): """Pre-delete handler""" - if not provider_type.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - task_sync_direct.delay( - class_to_path(instance.__class__), instance.pk, Direction.remove.value - ).get(propagate=False) + task_sync_direct_dispatch.send( + 
class_to_path(instance.__class__), + instance.pk, + Direction.remove.value, + ) pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) @@ -63,16 +44,6 @@ def register_signals( """Sync group membership""" if action not in ["post_add", "post_remove"]: return - if not provider_type.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - # reverse: instance is a Group, pk_set is a list of user pks - # non-reverse: instance is a User, pk_set is a list of groups - if reverse: - task_sync_m2m.delay(str(instance.pk), action, list(pk_set)) - else: - for group_pk in pk_set: - task_sync_m2m.delay(group_pk, action, [instance.pk]) + task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse) m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) diff --git a/authentik/lib/sync/outgoing/tasks.py b/authentik/lib/sync/outgoing/tasks.py index f78acb3429..1437480f0f 100644 --- a/authentik/lib/sync/outgoing/tasks.py +++ b/authentik/lib/sync/outgoing/tasks.py @@ -1,23 +1,17 @@ -from collections.abc import Callable -from dataclasses import asdict - -from celery import group -from celery.exceptions import Retry -from celery.result import allow_join_result from django.core.paginator import Paginator from django.db.models import Model, QuerySet from django.db.models.query import Q from django.utils.text import slugify -from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import Actor +from dramatiq.composition import group +from dramatiq.errors import Retry from structlog.stdlib import BoundLogger, get_logger from authentik.core.expression.exceptions import SkipObjectException from authentik.core.models import Group, User -from authentik.events.logs import LogEvent -from authentik.events.models import 
TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.events.utils import sanitize_item -from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT +from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.exceptions import ( BadRequestSyncException, @@ -27,11 +21,12 @@ from authentik.lib.sync.outgoing.exceptions import ( ) from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.utils.reflection import class_to_path, path_to_class +from authentik.tasks.models import Task class SyncTasks: - """Container for all sync 'tasks' (this class doesn't actually contain celery - tasks due to celery's magic, however exposes a number of functions to be called from tasks)""" + """Container for all sync 'tasks' (this class doesn't actually contain + tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)""" logger: BoundLogger @@ -39,107 +34,104 @@ class SyncTasks: super().__init__() self._provider_model = provider_model - def sync_all(self, single_sync: Callable[[int], None]): - for provider in self._provider_model.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ): - self.trigger_single_task(provider, single_sync) - - def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]): - """Wrapper single sync task that correctly sets time limits based - on the amount of objects that will be synced""" - users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) - groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) - soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT - time_limit = soft_time_limit * 1.5 - return sync_task.apply_async( - (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) - ) - - def sync_single( + def 
sync_paginator( self, - task: SystemTask, - provider_pk: int, - sync_objects: Callable[[int, int], list[str]], + current_task: Task, + provider: OutgoingSyncProvider, + sync_objects: Actor[[str, int, int, bool], None], + paginator: Paginator, + object_type: type[User | Group], + **options, ): + tasks = [] + for page in paginator.page_range: + page_sync = sync_objects.message_with_options( + args=(class_to_path(object_type), page, provider.pk), + time_limit=PAGE_TIMEOUT_MS, + # Assign tasks to the same schedule as the current one + rel_obj=current_task.rel_obj, + **options, + ) + tasks.append(page_sync) + return tasks + + def sync( + self, + provider_pk: int, + sync_objects: Actor[[str, int, int, bool], None], + ): + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), provider_pk=provider_pk, ) - provider = self._provider_model.objects.filter( + provider: OutgoingSyncProvider = self._provider_model.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False), pk=provider_pk, ).first() if not provider: + task.warning("No provider found. Is it assigned to an application?") return task.set_uid(slugify(provider.name)) - messages = [] - messages.append(_("Starting full provider sync")) + task.info("Starting full provider sync") self.logger.debug("Starting provider sync") - users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) - groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) - with allow_join_result(), provider.sync_lock as lock_acquired: + with provider.sync_lock as lock_acquired: if not lock_acquired: + task.info("Synchronization is already running. 
Skipping.") self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) return try: - messages.append(_("Syncing users")) - user_results = ( - group( - [ - sync_objects.signature( - args=(class_to_path(User), page, provider_pk), - time_limit=PAGE_TIMEOUT, - soft_time_limit=PAGE_TIMEOUT, - ) - for page in users_paginator.page_range - ] + users_tasks = group( + self.sync_paginator( + current_task=task, + provider=provider, + sync_objects=sync_objects, + paginator=provider.get_paginator(User), + object_type=User, ) - .apply_async() - .get() ) - for result in user_results: - for msg in result: - messages.append(LogEvent(**msg)) - messages.append(_("Syncing groups")) - group_results = ( - group( - [ - sync_objects.signature( - args=(class_to_path(Group), page, provider_pk), - time_limit=PAGE_TIMEOUT, - soft_time_limit=PAGE_TIMEOUT, - ) - for page in groups_paginator.page_range - ] + group_tasks = group( + self.sync_paginator( + current_task=task, + provider=provider, + sync_objects=sync_objects, + paginator=provider.get_paginator(Group), + object_type=Group, ) - .apply_async() - .get() ) - for result in group_results: - for msg in result: - messages.append(LogEvent(**msg)) + users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) + group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) except TransientSyncException as exc: self.logger.warning("transient sync exception", exc=exc) - raise task.retry(exc=exc) from exc + task.warning("Sync encountered a transient exception. 
Retrying", exc=exc) + raise Retry() from exc except StopSync as exc: - task.set_error(exc) + task.error(exc) return - task.set_status(TaskStatus.SUCCESSFUL, *messages) def sync_objects( - self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter + self, + object_type: str, + page: int, + provider_pk: int, + override_dry_run=False, + **filter, ): + task: Task = CurrentTask.get_task() _object_type: type[Model] = path_to_class(object_type) self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), provider_pk=provider_pk, object_type=object_type, ) - messages = [] - provider = self._provider_model.objects.filter(pk=provider_pk).first() + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() if not provider: - return messages + task.warning("No provider found. Is it assigned to an application?") + return + task.set_uid(slugify(provider.name)) # Override dry run mode if requested, however don't save the provider # so that scheduled sync tasks still run in dry_run mode if override_dry_run: @@ -147,25 +139,13 @@ class SyncTasks: try: client = provider.client_for_model(_object_type) except TransientSyncException: - return messages + return paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) if client.can_discover: self.logger.debug("starting discover") client.discover() self.logger.debug("starting sync for page", page=page) - messages.append( - asdict( - LogEvent( - _( - "Syncing page {page} of {object_type}".format( - page=page, object_type=_object_type._meta.verbose_name_plural - ) - ), - log_level="info", - logger=f"{provider._meta.verbose_name}@{object_type}", - ) - ) - ) + task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}") for obj in paginator.page(page).object_list: obj: Model try: @@ -174,89 +154,58 @@ class SyncTasks: 
self.logger.debug("skipping object due to SkipObject", obj=obj) continue except DryRunRejected as exc: - messages.append( - asdict( - LogEvent( - _("Dropping mutating request due to dry run"), - log_level="info", - logger=f"{provider._meta.verbose_name}@{object_type}", - attributes={ - "obj": sanitize_item(obj), - "method": exc.method, - "url": exc.url, - "body": exc.body, - }, - ) - ) + task.info( + "Dropping mutating request due to dry run", + obj=sanitize_item(obj), + method=exc.method, + url=exc.url, + body=exc.body, ) except BadRequestSyncException as exc: self.logger.warning("failed to sync object", exc=exc, obj=obj) - messages.append( - asdict( - LogEvent( - _( - ( - "Failed to sync {object_type} {object_name} " - "due to error: {error}" - ).format_map( - { - "object_type": obj._meta.verbose_name, - "object_name": str(obj), - "error": str(exc), - } - ) - ), - log_level="warning", - logger=f"{provider._meta.verbose_name}@{object_type}", - attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)}, - ) - ) + task.warning( + f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", + arguments=exc.args[1:], + obj=sanitize_item(obj), ) except TransientSyncException as exc: self.logger.warning("failed to sync object", exc=exc, user=obj) - messages.append( - asdict( - LogEvent( - _( - ( - "Failed to sync {object_type} {object_name} " - "due to transient error: {error}" - ).format_map( - { - "object_type": obj._meta.verbose_name, - "object_name": str(obj), - "error": str(exc), - } - ) - ), - log_level="warning", - logger=f"{provider._meta.verbose_name}@{object_type}", - attributes={"obj": sanitize_item(obj)}, - ) - ) + task.warning( + f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " + "transient error: {str(exc)}", + obj=sanitize_item(obj), ) except StopSync as exc: self.logger.warning("Stopping sync", exc=exc) - messages.append( - asdict( - LogEvent( - _( - "Stopping sync due to error: {error}".format_map( - { - "error": 
exc.detail(), - } - ) - ), - log_level="warning", - logger=f"{provider._meta.verbose_name}@{object_type}", - attributes={"obj": sanitize_item(obj)}, - ) - ) + task.warning( + f"Stopping sync due to error: {exc.detail()}", + obj=sanitize_item(obj), ) break - return messages - def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): + def sync_signal_direct_dispatch( + self, + task_sync_signal_direct: Actor[[str, str | int, int, str], None], + model: str, + pk: str | int, + raw_op: str, + ): + for provider in self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False) + ): + task_sync_signal_direct.send_with_options( + args=(model, pk, provider.pk, raw_op), + rel_obj=provider, + ) + + def sync_signal_direct( + self, + model: str, + pk: str | int, + provider_pk: int, + raw_op: str, + ): + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) @@ -264,65 +213,108 @@ class SyncTasks: instance = model_class.objects.filter(pk=pk).first() if not instance: return + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() + if not provider: + task.warning("No provider found. 
Is it assigned to an application?") + return + task.set_uid(slugify(provider.name)) operation = Direction(raw_op) + client = provider.client_for_model(instance.__class__) + # Check if the object is allowed within the provider's restrictions + queryset = provider.get_object_qs(instance.__class__) + if not queryset: + return + + # The queryset we get from the provider must include the instance we've got given + # otherwise ignore this provider + if not queryset.filter(pk=instance.pk).exists(): + return + + try: + if operation == Direction.add: + client.write(instance) + if operation == Direction.remove: + client.delete(instance) + except TransientSyncException as exc: + raise Retry() from exc + except SkipObjectException: + return + except DryRunRejected as exc: + self.logger.info("Rejected dry-run event", exc=exc) + except StopSync as exc: + self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) + + def sync_signal_m2m_dispatch( + self, + task_sync_signal_m2m: Actor[[str, int, str, list[int]], None], + instance_pk: str, + action: str, + pk_set: list[int], + reverse: bool, + ): for provider in self._provider_model.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False) ): - client = provider.client_for_model(instance.__class__) - # Check if the object is allowed within the provider's restrictions - queryset = provider.get_object_qs(instance.__class__) - if not queryset: - continue + # reverse: instance is a Group, pk_set is a list of user pks + # non-reverse: instance is a User, pk_set is a list of groups + if reverse: + task_sync_signal_m2m.send_with_options( + args=(instance_pk, provider.pk, action, list(pk_set)), + rel_obj=provider, + ) + else: + for pk in pk_set: + task_sync_signal_m2m.send_with_options( + args=(pk, provider.pk, action, [instance_pk]), + rel_obj=provider, + ) - # The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not 
queryset.filter(pk=instance.pk).exists(): - continue - - try: - if operation == Direction.add: - client.write(instance) - if operation == Direction.remove: - client.delete(instance) - except TransientSyncException as exc: - raise Retry() from exc - except SkipObjectException: - continue - except DryRunRejected as exc: - self.logger.info("Rejected dry-run event", exc=exc) - except StopSync as exc: - self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) - - def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): + def sync_signal_m2m( + self, + group_pk: str, + provider_pk: int, + action: str, + pk_set: list[int], + ): + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) group = Group.objects.filter(pk=group_pk).first() if not group: return - for provider in self._provider_model.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ): - # Check if the object is allowed within the provider's restrictions - queryset: QuerySet = provider.get_object_qs(Group) - # The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not queryset.filter(pk=group_pk).exists(): - continue + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() + if not provider: + task.warning("No provider found. 
Is it assigned to an application?") + return + task.set_uid(slugify(provider.name)) - client = provider.client_for_model(Group) - try: - operation = None - if action == "post_add": - operation = Direction.add - if action == "post_remove": - operation = Direction.remove - client.update_group(group, operation, pk_set) - except TransientSyncException as exc: - raise Retry() from exc - except SkipObjectException: - continue - except DryRunRejected as exc: - self.logger.info("Rejected dry-run event", exc=exc) - except StopSync as exc: - self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) + # Check if the object is allowed within the provider's restrictions + queryset: QuerySet = provider.get_object_qs(Group) + # The queryset we get from the provider must include the instance we've got given + # otherwise ignore this provider + if not queryset.filter(pk=group_pk).exists(): + return + + client = provider.client_for_model(Group) + try: + operation = None + if action == "post_add": + operation = Direction.add + if action == "post_remove": + operation = Direction.remove + client.update_group(group, operation, pk_set) + except TransientSyncException as exc: + raise Retry() from exc + except SkipObjectException: + return + except DryRunRejected as exc: + self.logger.info("Rejected dry-run event", exc=exc) + except StopSync as exc: + self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) diff --git a/authentik/lib/tests/test_config.py b/authentik/lib/tests/test_config.py index 18bc68cee4..3e076b71a8 100644 --- a/authentik/lib/tests/test_config.py +++ b/authentik/lib/tests/test_config.py @@ -23,8 +23,7 @@ class TestConfig(TestCase): """Test config loader""" check_deprecations_env_vars = { - ENV_PREFIX + "_REDIS__BROKER_URL": "redis://myredis:8327/43", - ENV_PREFIX + "_REDIS__BROKER_TRANSPORT_OPTIONS": "bWFzdGVybmFtZT1teW1hc3Rlcg==", + ENV_PREFIX + "_WORKER__CONCURRENCY": "2", ENV_PREFIX + "_REDIS__CACHE_TIMEOUT": "124s", ENV_PREFIX + 
"_REDIS__CACHE_TIMEOUT_FLOWS": "32m", ENV_PREFIX + "_REDIS__CACHE_TIMEOUT_POLICIES": "3920ns", @@ -176,14 +175,12 @@ class TestConfig(TestCase): config = ConfigLoader() config.update_from_env() config.check_deprecations() - self.assertEqual(config.get("redis.broker_url", UNSET), UNSET) - self.assertEqual(config.get("redis.broker_transport_options", UNSET), UNSET) + self.assertEqual(config.get("worker.concurrency", UNSET), UNSET) self.assertEqual(config.get("redis.cache_timeout", UNSET), UNSET) self.assertEqual(config.get("redis.cache_timeout_flows", UNSET), UNSET) self.assertEqual(config.get("redis.cache_timeout_policies", UNSET), UNSET) self.assertEqual(config.get("redis.cache_timeout_reputation", UNSET), UNSET) - self.assertEqual(config.get("broker.url"), "redis://myredis:8327/43") - self.assertEqual(config.get("broker.transport_options"), "bWFzdGVybmFtZT1teW1hc3Rlcg==") + self.assertEqual(config.get("worker.threads"), 2) self.assertEqual(config.get("cache.timeout"), "124s") self.assertEqual(config.get("cache.timeout_flows"), "32m") self.assertEqual(config.get("cache.timeout_policies"), "3920ns") diff --git a/authentik/outposts/apps.py b/authentik/outposts/apps.py index a7680a9aa5..ff2c55b77a 100644 --- a/authentik/outposts/apps.py +++ b/authentik/outposts/apps.py @@ -5,6 +5,8 @@ from structlog.stdlib import get_logger from authentik.blueprints.apps import ManagedAppConfig from authentik.lib.config import CONFIG +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec LOGGER = get_logger() @@ -60,3 +62,27 @@ class AuthentikOutpostConfig(ManagedAppConfig): outpost.save() else: Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.outposts.tasks import outpost_token_ensurer + + return [ + ScheduleSpec( + actor=outpost_token_ensurer, + crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *", + ), + ] + + @property + def 
global_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.outposts.tasks import outpost_connection_discovery + + return [ + ScheduleSpec( + actor=outpost_connection_discovery, + crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *", + send_on_startup=True, + paused=not CONFIG.get_bool("outposts.discover"), + ), + ] diff --git a/authentik/outposts/controllers/kubernetes.py b/authentik/outposts/controllers/kubernetes.py index 139fa4d5c0..e0d8f49db0 100644 --- a/authentik/outposts/controllers/kubernetes.py +++ b/authentik/outposts/controllers/kubernetes.py @@ -101,7 +101,13 @@ class KubernetesController(BaseController): all_logs = [] for reconcile_key in self.reconcile_order: if reconcile_key in self.outpost.config.kubernetes_disabled_components: - all_logs += [f"{reconcile_key.title()}: Disabled"] + all_logs.append( + LogEvent( + log_level="info", + event=f"{reconcile_key.title()}: Disabled", + logger=str(type(self)), + ) + ) continue with capture_logs() as logs: reconciler_cls = self.reconcilers.get(reconcile_key) @@ -134,7 +140,13 @@ class KubernetesController(BaseController): all_logs = [] for reconcile_key in self.reconcile_order: if reconcile_key in self.outpost.config.kubernetes_disabled_components: - all_logs += [f"{reconcile_key.title()}: Disabled"] + all_logs.append( + LogEvent( + log_level="info", + event=f"{reconcile_key.title()}: Disabled", + logger=str(type(self)), + ) + ) continue with capture_logs() as logs: reconciler_cls = self.reconcilers.get(reconcile_key) diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 2fdb190158..6bafa61f4e 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -35,7 +35,10 @@ from authentik.events.models import Event, EventAction from authentik.lib.config import CONFIG from authentik.lib.models import InheritanceForeignKey, SerializerModel from authentik.lib.sentry import SentryIgnoredException +from authentik.lib.utils.time import fqdn_rand from 
authentik.outposts.controllers.k8s.utils import get_namespace +from authentik.tasks.schedules.common import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel OUR_VERSION = parse(__version__) OUTPOST_HELLO_INTERVAL = 10 @@ -114,7 +117,7 @@ class OutpostServiceConnectionState: healthy: bool -class OutpostServiceConnection(models.Model): +class OutpostServiceConnection(ScheduledModel, models.Model): """Connection details for an Outpost Controller, like Docker or Kubernetes""" uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) @@ -144,11 +147,11 @@ class OutpostServiceConnection(models.Model): @property def state(self) -> OutpostServiceConnectionState: """Get state of service connection""" - from authentik.outposts.tasks import outpost_service_connection_state + from authentik.outposts.tasks import outpost_service_connection_monitor state = cache.get(self.state_key, None) if not state: - outpost_service_connection_state.delay(self.pk) + outpost_service_connection_monitor.send_with_options(args=(self.pk), rel_obj=self) return OutpostServiceConnectionState("", False) return state @@ -159,6 +162,20 @@ class OutpostServiceConnection(models.Model): # since the response doesn't use the correct inheritance return "" + @property + def schedule_specs(self) -> list[ScheduleSpec]: + from authentik.outposts.tasks import outpost_service_connection_monitor + + return [ + ScheduleSpec( + actor=outpost_service_connection_monitor, + uid=self.name, + args=(self.pk,), + crontab="3-59/15 * * * *", + send_on_save=True, + ), + ] + class DockerServiceConnection(SerializerModel, OutpostServiceConnection): """Service Connection to a Docker endpoint""" @@ -243,7 +260,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): return "ak-service-connection-kubernetes-form" -class Outpost(SerializerModel, ManagedModel): +class Outpost(ScheduledModel, SerializerModel, ManagedModel): """Outpost instance which manages a service 
user and token""" uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) @@ -297,6 +314,21 @@ class Outpost(SerializerModel, ManagedModel): """Username for service user""" return f"ak-outpost-{self.uuid.hex}" + @property + def schedule_specs(self) -> list[ScheduleSpec]: + from authentik.outposts.tasks import outpost_controller + + return [ + ScheduleSpec( + actor=outpost_controller, + uid=self.name, + args=(self.pk,), + kwargs={"action": "up", "from_cache": False}, + crontab=f"{fqdn_rand('outpost_controller')} */4 * * *", + send_on_save=True, + ), + ] + def build_user_permissions(self, user: User): """Create per-object and global permissions for outpost service-account""" # To ensure the user only has the correct permissions, we delete all of them and re-add diff --git a/authentik/outposts/settings.py b/authentik/outposts/settings.py deleted file mode 100644 index c29f9f64ab..0000000000 --- a/authentik/outposts/settings.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Outposts Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "outposts_controller": { - "task": "authentik.outposts.tasks.outpost_controller_all", - "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), - "options": {"queue": "authentik_scheduled"}, - }, - "outposts_service_connection_check": { - "task": "authentik.outposts.tasks.outpost_service_connection_monitor", - "schedule": crontab(minute="3-59/15"), - "options": {"queue": "authentik_scheduled"}, - }, - "outpost_token_ensurer": { - "task": "authentik.outposts.tasks.outpost_token_ensurer", - "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), - "options": {"queue": "authentik_scheduled"}, - }, - "outpost_connection_discovery": { - "task": "authentik.outposts.tasks.outpost_connection_discovery", - "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), - "options": {"queue": 
"authentik_scheduled"}, - }, -} diff --git a/authentik/outposts/signals.py b/authentik/outposts/signals.py index 93d6731871..b08c0ecef2 100644 --- a/authentik/outposts/signals.py +++ b/authentik/outposts/signals.py @@ -1,7 +1,6 @@ """authentik outpost signals""" from django.core.cache import cache -from django.db.models import Model from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.dispatch import receiver from structlog.stdlib import get_logger @@ -9,27 +8,19 @@ from structlog.stdlib import get_logger from authentik.brands.models import Brand from authentik.core.models import AuthenticatedSession, Provider from authentik.crypto.models import CertificateKeyPair -from authentik.lib.utils.reflection import class_to_path -from authentik.outposts.models import Outpost, OutpostServiceConnection +from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection from authentik.outposts.tasks import ( CACHE_KEY_OUTPOST_DOWN, outpost_controller, - outpost_post_save, + outpost_send_update, outpost_session_end, ) LOGGER = get_logger() -UPDATE_TRIGGERING_MODELS = ( - Outpost, - OutpostServiceConnection, - Provider, - CertificateKeyPair, - Brand, -) @receiver(pre_save, sender=Outpost) -def pre_save_outpost(sender, instance: Outpost, **_): +def outpost_pre_save(sender, instance: Outpost, **_): """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, we call down and then wait for the up after save""" old_instances = Outpost.objects.filter(pk=instance.pk) @@ -44,43 +35,89 @@ def pre_save_outpost(sender, instance: Outpost, **_): if bool(dirty): LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) - outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) + outpost_controller.send_with_options( + args=(instance.pk.hex,), + kwargs={"action": "down", "from_cache": True}, + 
@receiver(m2m_changed, sender=Outpost.providers.through)
def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_):
    """Update outpost on m2m change, when providers are added or removed"""
    if action not in ["post_add", "post_remove", "post_clear"]:
        return
    if isinstance(instance, Outpost):
        outpost_controller.send_with_options(
            args=(instance.pk,),
            rel_obj=instance.service_connection,
        )
        outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
    elif isinstance(instance, OutpostModel):
        for outpost in instance.outpost_set.all():
            # FIX: schedule the controller for each related Outpost, not for the
            # OutpostModel itself — an OutpostModel (e.g. a Provider) has no
            # `service_connection` attribute and its pk is not an Outpost pk.
            outpost_controller.send_with_options(
                args=(outpost.pk,),
                rel_obj=outpost.service_connection,
            )
            outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)


@receiver(post_save, sender=Outpost)
def outpost_post_save(sender, instance: Outpost, created: bool, **_):
    """On Outpost save: ensure token/service-account exist for new outposts,
    then re-run the controller and push an update over the WS channel."""
    if created:
        LOGGER.info("New outpost saved, ensuring initial token and user are created")
        # Accessing .token lazily creates the token if it does not exist yet
        _ = instance.token
    outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection)
    outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)


def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
    """Notify every Outpost attached to a saved service connection or OutpostModel."""
    for outpost in instance.outpost_set.all():
        outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)


post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False)
for subclass in OutpostModel.__subclasses__():
    post_save.connect(outpost_related_post_save, sender=subclass, weak=False)


def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_):
    """Walk reverse relations of Brand/CertificateKeyPair to find OutpostModels
    that reference the saved instance, and notify their Outposts."""
    for field in instance._meta.get_fields():
        # Each field is checked if it has a `related_model` attribute (when ForeignKeys or M2Ms)
        # are used, and if it has a value
        if not hasattr(field, "related_model"):
            continue
        if not field.related_model:
            continue
        if not issubclass(field.related_model, OutpostModel):
            continue

        field_name = f"{field.name}_set"
        if not hasattr(instance, field_name):
            continue

        LOGGER.debug("triggering outpost update from field", field=field.name)
        # Because the Outpost Model has an M2M to Provider,
        # we have to iterate over the entire QS
        for reverse in getattr(instance, field_name).all():
            if isinstance(reverse, OutpostModel):
                for outpost in reverse.outpost_set.all():
                    outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)


post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False)
post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False)


@receiver(pre_delete, sender=Outpost)
def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
    """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
    instance.user.delete()
    cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
    outpost_controller.send(instance.pk.hex, action="down", from_cache=True)


@receiver(pre_delete, sender=AuthenticatedSession)
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
    """Catch logout by expiring sessions being deleted"""
    outpost_session_end.send(instance.session.session_key)
OutpostType, ServiceConnectionInvalid, @@ -44,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController from authentik.providers.rac.controllers.kubernetes import RACKubernetesController from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController -from authentik.root.celery import CELERY_APP +from authentik.tasks.models import Task LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" @@ -83,8 +80,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: return None -@CELERY_APP.task() -def outpost_service_connection_state(connection_pk: Any): +@actor(description=_("Update cached state of service connection.")) +def outpost_service_connection_monitor(connection_pk: Any): """Update cached state of a service connection""" connection: OutpostServiceConnection = ( OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() @@ -108,37 +105,11 @@ def outpost_service_connection_state(connection_pk: Any): cache.set(connection.state_key, state, timeout=None) -@CELERY_APP.task( - bind=True, - base=SystemTask, - throws=(DatabaseError, ProgrammingError, InternalError), -) -@prefill_task -def outpost_service_connection_monitor(self: SystemTask): - """Regularly check the state of Outpost Service Connections""" - connections = OutpostServiceConnection.objects.all() - for connection in connections.iterator(): - outpost_service_connection_state.delay(connection.pk) - self.set_status( - TaskStatus.SUCCESSFUL, - f"Successfully updated {len(connections)} connections.", - ) - - -@CELERY_APP.task( - throws=(DatabaseError, ProgrammingError, InternalError), -) -def outpost_controller_all(): - """Launch Controller for all Outposts which support it""" - for outpost in Outpost.objects.exclude(service_connection=None): - outpost_controller.delay(outpost.pk.hex, "up", 
@actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens."))
def outpost_token_ensurer():
    """Periodically ensure that all Outposts have valid Service Accounts and Tokens."""
    task: Task = CurrentTask.get_task()
    outposts = Outpost.objects.all()
    for outpost in outposts:
        # Accessing .token lazily (re)creates the token if missing; permissions
        # are rebuilt so the service account holds exactly what it needs.
        _ = outpost.token
        outpost.build_user_permissions(outpost.user)
    task.info(f"Successfully checked {len(outposts)} Outposts.")
@actor(description=_("Send update to outpost"))
def outpost_send_update(pk: Any):
    """Update outpost instance"""
    outpost = Outpost.objects.filter(pk=pk).first()
    if outpost is None:
        # Outpost was deleted between dispatch and execution; nothing to do
        return
    # Ensure token again, because this function is called when anything related to an
    # OutpostModel is saved, so we can be sure permissions are right
    _ = outpost.token
    outpost.build_user_permissions(outpost.user)
    channel_layer = get_channel_layer()
    group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
    LOGGER.debug("sending update", channel=group, outpost=outpost)
    async_to_sync(channel_layer.group_send)(group, {"type": "event.update"})
messages.append("Detected kubeconfig") + self.info("Detected kubeconfig") kubeconfig_local_name = f"k8s-{gethostname()}" if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): - messages.append("Creating kubeconfig Service Connection") + self.info("Creating kubeconfig Service Connection") with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: KubernetesServiceConnection.objects.create( name=kubeconfig_local_name, @@ -286,20 +197,18 @@ def outpost_connection_discovery(self: SystemTask): unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path socket = Path(unix_socket_path) if socket.exists() and access(socket, R_OK): - messages.append("Detected local docker socket") + self.info("Detected local docker socket") if len(DockerServiceConnection.objects.filter(local=True)) == 0: - messages.append("Created Service Connection for docker") + self.info("Created Service Connection for docker") DockerServiceConnection.objects.create( name="Local Docker connection", local=True, url=unix_socket_path, ) - self.set_status(TaskStatus.SUCCESSFUL, *messages) -@CELERY_APP.task() +@actor(description=_("Terminate session on all outposts.")) def outpost_session_end(session_id: str): - """Update outpost instances connected to a single outpost""" layer = get_channel_layer() hashed_session_id = hash_session_key(session_id) for outpost in Outpost.objects.all(): diff --git a/authentik/outposts/tests/test_sa.py b/authentik/outposts/tests/test_sa.py index 59238a2cf8..92ac083f6e 100644 --- a/authentik/outposts/tests/test_sa.py +++ b/authentik/outposts/tests/test_sa.py @@ -37,6 +37,7 @@ class OutpostTests(TestCase): # We add a provider, user should only have access to outpost and provider outpost.providers.add(provider) + provider.refresh_from_db() permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( "content_type__model" ) diff --git a/authentik/providers/proxy/apps.py b/authentik/providers/proxy/apps.py index 
@actor(description=_("Terminate session on Proxy outpost."))
def proxy_on_logout(session_id: str):
    """Broadcast a provider-specific logout event for the given session
    to every Proxy-type outpost over the channel layer."""
    channel_layer = get_channel_layer()
    hashed_session_id = hash_session_key(session_id)
    # Same payload for every outpost; only the target group differs
    message = {
        "type": "event.provider.specific",
        "sub_type": "logout",
        "session_id": hashed_session_id,
    }
    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
        async_to_sync(channel_layer.group_send)(group, message)
authentik.providers.scim.models import SCIMProvider -from authentik.providers.scim.tasks import scim_sync, sync_tasks from authentik.tenants.management import TenantCommand LOGGER = get_logger() @@ -21,4 +20,5 @@ class Command(TenantCommand): if not provider: LOGGER.warning("Provider does not exist", name=provider_name) continue - sync_tasks.trigger_single_task(provider, scim_sync).get() + for schedule in provider.schedules.all(): + schedule.send().get_result() diff --git a/authentik/providers/scim/models.py b/authentik/providers/scim/models.py index 606cec31cd..4a8f047a34 100644 --- a/authentik/providers/scim/models.py +++ b/authentik/providers/scim/models.py @@ -7,6 +7,7 @@ from django.db import models from django.db.models import QuerySet from django.templatetags.static import static from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor from rest_framework.serializers import Serializer from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes @@ -99,6 +100,12 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider): def icon_url(self) -> str | None: return static("authentik/sources/scim.png") + @property + def sync_actor(self) -> Actor: + from authentik.providers.scim.tasks import scim_sync + + return scim_sync + def client_for_model( self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: diff --git a/authentik/providers/scim/settings.py b/authentik/providers/scim/settings.py deleted file mode 100644 index 0a0963ea90..0000000000 --- a/authentik/providers/scim/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""SCIM task Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "providers_scim_sync": { - "task": "authentik.providers.scim.tasks.scim_sync_all", - "schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"), - "options": 
{"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/providers/scim/signals.py b/authentik/providers/scim/signals.py index be9855f26d..c52c2e9e31 100644 --- a/authentik/providers/scim/signals.py +++ b/authentik/providers/scim/signals.py @@ -2,11 +2,10 @@ from authentik.lib.sync.outgoing.signals import register_signals from authentik.providers.scim.models import SCIMProvider -from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m +from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch register_signals( SCIMProvider, - task_sync_single=scim_sync, - task_sync_direct=scim_sync_direct, - task_sync_m2m=scim_sync_m2m, + task_sync_direct_dispatch=scim_sync_direct_dispatch, + task_sync_m2m_dispatch=scim_sync_m2m_dispatch, ) diff --git a/authentik/providers/scim/tasks.py b/authentik/providers/scim/tasks.py index f3c2e4d493..f4e374e2f4 100644 --- a/authentik/providers/scim/tasks.py +++ b/authentik/providers/scim/tasks.py @@ -1,37 +1,40 @@ """SCIM Provider tasks""" -from authentik.events.system_tasks import SystemTask -from authentik.lib.sync.outgoing.exceptions import TransientSyncException +from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import actor + from authentik.lib.sync.outgoing.tasks import SyncTasks from authentik.providers.scim.models import SCIMProvider -from authentik.root.celery import CELERY_APP sync_tasks = SyncTasks(SCIMProvider) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Sync SCIM provider objects.")) def scim_sync_objects(*args, **kwargs): return sync_tasks.sync_objects(*args, **kwargs) -@CELERY_APP.task( - base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True -) -def scim_sync(self, provider_pk: int, *args, **kwargs): +@actor(description=_("Full sync for SCIM provider.")) +def scim_sync(provider_pk: int, *args, **kwargs): """Run full sync for SCIM provider""" 
- return sync_tasks.sync_single(self, provider_pk, scim_sync_objects) + return sync_tasks.sync(provider_pk, scim_sync_objects) -@CELERY_APP.task() -def scim_sync_all(): - return sync_tasks.sync_all(scim_sync) - - -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Sync a direct object (user, group) for SCIM provider.")) def scim_sync_direct(*args, **kwargs): return sync_tasks.sync_signal_direct(*args, **kwargs) -@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) +@actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers.")) +def scim_sync_direct_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs) + + +@actor(description=_("Sync a related object (memberships) for SCIM provider.")) def scim_sync_m2m(*args, **kwargs): return sync_tasks.sync_signal_m2m(*args, **kwargs) + + +@actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers.")) +def scim_sync_m2m_dispatch(*args, **kwargs): + return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs) diff --git a/authentik/providers/scim/tests/test_client.py b/authentik/providers/scim/tests/test_client.py index 0d3dd57e51..2735c23d3b 100644 --- a/authentik/providers/scim/tests/test_client.py +++ b/authentik/providers/scim/tests/test_client.py @@ -8,7 +8,7 @@ from authentik.core.models import Application from authentik.lib.generators import generate_id from authentik.providers.scim.clients.base import SCIMClient from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync_all +from authentik.providers.scim.tasks import scim_sync class SCIMClientTests(TestCase): @@ -85,6 +85,6 @@ class SCIMClientTests(TestCase): self.assertEqual(mock.call_count, 1) self.assertEqual(mock.request_history[0].method, "GET") - def test_scim_sync_all(self): - """test scim_sync_all 
task""" - scim_sync_all() + def test_scim_sync(self): + """test scim_sync task""" + scim_sync.send(self.provider.pk).get_result() diff --git a/authentik/providers/scim/tests/test_membership.py b/authentik/providers/scim/tests/test_membership.py index 24084622fc..6d9eb8dab2 100644 --- a/authentik/providers/scim/tests/test_membership.py +++ b/authentik/providers/scim/tests/test_membership.py @@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User from authentik.lib.generators import generate_id from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync, sync_tasks +from authentik.providers.scim.tasks import scim_sync from authentik.tenants.models import Tenant @@ -79,17 +79,15 @@ class SCIMMembershipTests(TestCase): ) self.configure() - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) - self.assertEqual(mocker.call_count, 6) + self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.request_history[0].method, "GET") - self.assertEqual(mocker.request_history[1].method, "GET") + self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[3].method, "POST") - self.assertEqual(mocker.request_history[4].method, "GET") - self.assertEqual(mocker.request_history[5].method, "POST") self.assertJSONEqual( - mocker.request_history[3].body, + mocker.request_history[1].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "emails": [], @@ -101,7 +99,7 @@ class SCIMMembershipTests(TestCase): }, ) self.assertJSONEqual( - mocker.request_history[5].body, + mocker.request_history[3].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "externalId": str(group.pk), @@ -169,17 +167,15 @@ class SCIMMembershipTests(TestCase): ) self.configure() - 
sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) - self.assertEqual(mocker.call_count, 6) + self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.request_history[0].method, "GET") - self.assertEqual(mocker.request_history[1].method, "GET") + self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[3].method, "POST") - self.assertEqual(mocker.request_history[4].method, "GET") - self.assertEqual(mocker.request_history[5].method, "POST") self.assertJSONEqual( - mocker.request_history[3].body, + mocker.request_history[1].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "active": True, @@ -191,7 +187,7 @@ class SCIMMembershipTests(TestCase): }, ) self.assertJSONEqual( - mocker.request_history[5].body, + mocker.request_history[3].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "externalId": str(group.pk), @@ -287,17 +283,15 @@ class SCIMMembershipTests(TestCase): ) self.configure() - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) - self.assertEqual(mocker.call_count, 6) + self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.request_history[0].method, "GET") - self.assertEqual(mocker.request_history[1].method, "GET") + self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[3].method, "POST") - self.assertEqual(mocker.request_history[4].method, "GET") - self.assertEqual(mocker.request_history[5].method, "POST") self.assertJSONEqual( - mocker.request_history[3].body, + mocker.request_history[1].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "emails": [], @@ -309,7 +303,7 @@ class SCIMMembershipTests(TestCase): }, ) self.assertJSONEqual( - mocker.request_history[5].body, + 
mocker.request_history[3].body, { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "externalId": str(group.pk), diff --git a/authentik/providers/scim/tests/test_user.py b/authentik/providers/scim/tests/test_user.py index 2f32b68857..d29d14ff67 100644 --- a/authentik/providers/scim/tests/test_user.py +++ b/authentik/providers/scim/tests/test_user.py @@ -9,11 +9,11 @@ from requests_mock import Mocker from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application, Group, User -from authentik.events.models import SystemTask from authentik.lib.generators import generate_id from authentik.lib.sync.outgoing.base import SAFE_METHODS from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync, sync_tasks +from authentik.providers.scim.tasks import scim_sync, scim_sync_objects +from authentik.tasks.models import Task from authentik.tenants.models import Tenant @@ -356,7 +356,7 @@ class SCIMUserTests(TestCase): email=f"{uid}@goauthentik.io", ) - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) self.assertEqual(mock.call_count, 5) self.assertEqual(mock.request_history[0].method, "GET") @@ -428,14 +428,19 @@ class SCIMUserTests(TestCase): email=f"{uid}@goauthentik.io", ) - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) self.assertEqual(mock.call_count, 3) for request in mock.request_history: self.assertIn(request.method, SAFE_METHODS) - task = SystemTask.objects.filter(uid=slugify(self.provider.name)).first() + task = list( + Task.objects.filter( + actor_name=scim_sync_objects.actor_name, + _uid=slugify(self.provider.name), + ).order_by("-mtime") + )[1] self.assertIsNotNone(task) - drop_msg = task.messages[3] + drop_msg = task._messages[3] self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") self.assertIsNotNone(drop_msg["attributes"]["url"]) 
self.assertIsNotNone(drop_msg["attributes"]["body"]) diff --git a/authentik/root/celery.py b/authentik/root/celery.py deleted file mode 100644 index ecf0c55fba..0000000000 --- a/authentik/root/celery.py +++ /dev/null @@ -1,167 +0,0 @@ -"""authentik core celery""" - -import os -from collections.abc import Callable -from contextvars import ContextVar -from logging.config import dictConfig -from pathlib import Path -from tempfile import gettempdir - -from celery import bootsteps -from celery.apps.worker import Worker -from celery.signals import ( - after_task_publish, - setup_logging, - task_failure, - task_internal_error, - task_postrun, - task_prerun, - worker_ready, -) -from celery.worker.control import inspect_command -from django.conf import settings -from django.db import ProgrammingError -from django_tenants.utils import get_public_schema_name -from structlog.contextvars import STRUCTLOG_KEY_PREFIX -from structlog.stdlib import get_logger -from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp - -from authentik import get_full_version -from authentik.lib.sentry import should_ignore_exception - -# set the default Django settings module for the 'celery' program. 
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") - -LOGGER = get_logger() -CELERY_APP = TenantAwareCeleryApp("authentik") -CTX_TASK_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "task_id", default=Ellipsis) -HEARTBEAT_FILE = Path(gettempdir() + "/authentik-worker") - - -@setup_logging.connect -def config_loggers(*args, **kwargs): - """Apply logging settings from settings.py to celery""" - dictConfig(settings.LOGGING) - - -@after_task_publish.connect -def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): - """Log task_id after it was published""" - info = headers if "task" in headers else body - LOGGER.info( - "Task published", - task_id=info.get("id", "").replace("-", ""), - task_name=info.get("task", ""), - ) - - -@task_prerun.connect -def task_prerun_hook(task_id: str, task, *args, **kwargs): - """Log task_id on worker""" - request_id = "task-" + task_id.replace("-", "") - CTX_TASK_ID.set(request_id) - LOGGER.info("Task started", task_id=task_id, task_name=task.__name__) - - -@task_postrun.connect -def task_postrun_hook(task_id: str, task, *args, retval=None, state=None, **kwargs): - """Log task_id on worker""" - CTX_TASK_ID.set(...) - LOGGER.info( - "Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state - ) - - -@task_failure.connect -@task_internal_error.connect -def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwargs): - """Create system event for failed task""" - from authentik.events.models import Event, EventAction - - LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) - CTX_TASK_ID.set(...) 
- if not should_ignore_exception(exception): - Event.new( - EventAction.SYSTEM_EXCEPTION, message="Failed to execute task", task_id=task_id - ).with_exception(exception).save() - - -def _get_startup_tasks_default_tenant() -> list[Callable]: - """Get all tasks to be run on startup for the default tenant""" - from authentik.outposts.tasks import outpost_connection_discovery - - return [ - outpost_connection_discovery, - ] - - -def _get_startup_tasks_all_tenants() -> list[Callable]: - """Get all tasks to be run on startup for all tenants""" - return [] - - -@worker_ready.connect -def worker_ready_hook(*args, **kwargs): - """Run certain tasks on worker start""" - from authentik.tenants.models import Tenant - - LOGGER.info("Dispatching startup tasks...") - - def _run_task(task: Callable): - try: - task.delay() - except ProgrammingError as exc: - LOGGER.warning("Startup task failed", task=task, exc=exc) - - for task in _get_startup_tasks_default_tenant(): - with Tenant.objects.get(schema_name=get_public_schema_name()): - _run_task(task) - - for task in _get_startup_tasks_all_tenants(): - for tenant in Tenant.objects.filter(ready=True): - with tenant: - _run_task(task) - - from authentik.blueprints.v1.tasks import start_blueprint_watcher - - start_blueprint_watcher() - - -class LivenessProbe(bootsteps.StartStopStep): - """Add a timed task to touch a temporary file for healthchecking reasons""" - - requires = {"celery.worker.components:Timer"} - - def __init__(self, parent, **kwargs): - super().__init__(parent, **kwargs) - self.requests = [] - self.tref = None - - def start(self, parent: Worker): - self.tref = parent.timer.call_repeatedly( - 10.0, - self.update_heartbeat_file, - (parent,), - priority=10, - ) - self.update_heartbeat_file(parent) - - def stop(self, parent: Worker): - HEARTBEAT_FILE.unlink(missing_ok=True) - - def update_heartbeat_file(self, worker: Worker): - """Touch heartbeat file""" - HEARTBEAT_FILE.touch() - - -@inspect_command(default_timeout=0.2) -def 
ping(state, **kwargs): - """Ping worker(s).""" - return {"ok": "pong", "version": get_full_version()} - - -CELERY_APP.config_from_object(settings.CELERY) - -# Load task modules from all registered Django app configs. -CELERY_APP.autodiscover_tasks() -CELERY_APP.steps["worker"].add(LivenessProbe) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index bb56cafa3c..68b7f629c9 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -4,9 +4,9 @@ import importlib from collections import OrderedDict from hashlib import sha512 from pathlib import Path +from tempfile import gettempdir import orjson -from celery.schedules import crontab from sentry_sdk import set_tag from xmlsec import enable_debug_trace @@ -65,14 +65,18 @@ SHARED_APPS = [ "pgactivity", "pglock", "channels", + "django_dramatiq_postgres", + "authentik.tasks", ] TENANT_APPS = [ "django.contrib.auth", "django.contrib.contenttypes", "django.contrib.sessions", + "pgtrigger", "authentik.admin", "authentik.api", "authentik.crypto", + "authentik.events", "authentik.flows", "authentik.outposts", "authentik.policies.dummy", @@ -120,6 +124,7 @@ TENANT_APPS = [ "authentik.stages.user_login", "authentik.stages.user_logout", "authentik.stages.user_write", + "authentik.tasks.schedules", "authentik.brands", "authentik.blueprints", "guardian", @@ -165,6 +170,7 @@ SPECTACULAR_SETTINGS = { "PolicyEngineMode": "authentik.policies.models.PolicyEngineMode", "PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes", "ProxyMode": "authentik.providers.proxy.models.ProxyMode", + "TaskAggregatedStatusEnum": "authentik.tasks.models.TaskStatus", "UserTypeEnum": "authentik.core.models.UserTypes", "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", }, @@ -342,37 +348,86 @@ USE_TZ = True LOCALE_PATHS = ["./locale"] -CELERY = { - "task_soft_time_limit": 600, - "worker_max_tasks_per_child": 50, - "worker_concurrency": CONFIG.get_int("worker.concurrency"), - 
"beat_schedule": { - "clean_expired_models": { - "task": "authentik.core.tasks.clean_expired_models", - "schedule": crontab(minute="2-59/5"), - "options": {"queue": "authentik_scheduled"}, - }, - "user_cleanup": { - "task": "authentik.core.tasks.clean_temporary_users", - "schedule": crontab(minute="9-59/5"), - "options": {"queue": "authentik_scheduled"}, - }, + +# Tests + +TEST = False +TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" + + +# Dramatiq + +DRAMATIQ = { + "broker_class": "authentik.tasks.broker.Broker", + "channel_prefix": "authentik", + "task_model": "authentik.tasks.models.Task", + "task_purge_interval": timedelta_from_string( + CONFIG.get("worker.task_purge_interval") + ).total_seconds(), + "task_expiration": timedelta_from_string(CONFIG.get("worker.task_expiration")).total_seconds(), + "autodiscovery": { + "enabled": True, + "setup_module": "authentik.tasks.setup", + "apps_prefix": "authentik", }, - "beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler", - "task_create_missing_queues": True, - "task_default_queue": "authentik", - "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")), - "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")), - "broker_transport_options": CONFIG.get_dict_from_b64_json( - "broker.transport_options", {"retry_policy": {"timeout": 5.0}} + "worker": { + "processes": CONFIG.get_int("worker.processes", 2), + "threads": CONFIG.get_int("worker.threads", 1), + "consumer_listen_timeout": timedelta_from_string( + CONFIG.get("worker.consumer_listen_timeout") + ).total_seconds(), + "watch_folder": BASE_DIR / "authentik", + }, + "scheduler_class": "authentik.tasks.schedules.scheduler.Scheduler", + "schedule_model": "authentik.tasks.schedules.models.Schedule", + "scheduler_interval": timedelta_from_string( + CONFIG.get("worker.scheduler_interval") + ).total_seconds(), + "middlewares": ( + 
("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}), + # TODO: fixme + # ("dramatiq.middleware.prometheus.Prometheus", {}), + ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), + ("dramatiq.middleware.age_limit.AgeLimit", {}), + ( + "dramatiq.middleware.time_limit.TimeLimit", + { + "time_limit": timedelta_from_string( + CONFIG.get("worker.task_default_time_limit") + ).total_seconds() + * 1000 + }, + ), + ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), + ("dramatiq.middleware.callbacks.Callbacks", {}), + ("dramatiq.middleware.pipelines.Pipelines", {}), + ( + "dramatiq.middleware.retries.Retries", + {"max_retries": CONFIG.get_int("worker.task_max_retries") if not TEST else 0}, + ), + ("dramatiq.results.middleware.Results", {"store_results": True}), + ("django_dramatiq_postgres.middleware.CurrentTask", {}), + ("authentik.tasks.middleware.TenantMiddleware", {}), + ("authentik.tasks.middleware.RelObjMiddleware", {}), + ("authentik.tasks.middleware.MessagesMiddleware", {}), + ("authentik.tasks.middleware.LoggingMiddleware", {}), + ("authentik.tasks.middleware.DescriptionMiddleware", {}), + ("authentik.tasks.middleware.WorkerHealthcheckMiddleware", {}), + ("authentik.tasks.middleware.WorkerStatusMiddleware", {}), + ( + "authentik.tasks.middleware.MetricsMiddleware", + { + "multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"), + "prefix": "authentik", + }, + ), ), - "result_backend_transport_options": CONFIG.get_dict_from_b64_json( - "result_backend.transport_options", {"retry_policy": {"timeout": 5.0}} - ), - "redis_retry_on_timeout": True, + "test": TEST, } + # Sentry integration + env = get_env() _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) if _ERROR_REPORTING: @@ -433,9 +488,6 @@ else: MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] -TEST = False -TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" - 
structlog_configure() LOGGING = get_logger_config() @@ -446,7 +498,6 @@ _DISALLOWED_ITEMS = [ "INSTALLED_APPS", "MIDDLEWARE", "AUTHENTICATION_BACKENDS", - "CELERY", "SPECTACULAR_SETTINGS", "REST_FRAMEWORK", ] @@ -473,7 +524,6 @@ def _update_settings(app_path: str): AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) - CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) for _attr in dir(settings_module): if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: globals()[_attr] = getattr(settings_module, _attr) @@ -482,7 +532,6 @@ def _update_settings(app_path: str): if DEBUG: - CELERY["task_always_eager"] = True REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( "rest_framework.renderers.BrowsableAPIRenderer" ) @@ -502,10 +551,6 @@ try: except ImportError: pass -# Import events after other apps since it relies on tasks and other things from all apps -# being imported for @prefill_task -TENANT_APPS.append("authentik.events") - # Load subapps's settings for _app in set(SHARED_APPS + TENANT_APPS): diff --git a/authentik/root/test_runner.py b/authentik/root/test_runner.py index 75a5924158..71ba33c7c9 100644 --- a/authentik/root/test_runner.py +++ b/authentik/root/test_runner.py @@ -16,6 +16,7 @@ from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR from authentik.lib.config import CONFIG from authentik.lib.sentry import sentry_init from authentik.root.signals import post_startup, pre_startup, startup +from authentik.tasks.test import use_test_broker # globally set maxDiff to none to show full assert error TestCase.maxDiff = None @@ -60,7 +61,7 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover def _setup_test_environment(self): """Configure test environment settings""" settings.TEST = True - 
settings.CELERY["task_always_eager"] = True + settings.DRAMATIQ["test"] = True # Test-specific configuration test_config = { @@ -84,6 +85,8 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover sentry_init() self.logger.debug("Test environment configured") + use_test_broker() + # Send startup signals pre_startup.send(sender=self, mode="test") startup.send(sender=self, mode="test") diff --git a/authentik/sources/kerberos/api/source.py b/authentik/sources/kerberos/api/source.py index 9b8120e2e0..9d77a74bd2 100644 --- a/authentik/sources/kerberos/api/source.py +++ b/authentik/sources/kerberos/api/source.py @@ -2,19 +2,19 @@ from django.core.cache import cache from drf_spectacular.utils import extend_schema -from guardian.shortcuts import get_objects_for_user from rest_framework.decorators import action -from rest_framework.fields import BooleanField, SerializerMethodField +from rest_framework.fields import SerializerMethodField from rest_framework.request import Request from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet from authentik.core.api.sources import SourceSerializer from authentik.core.api.used_by import UsedByMixin -from authentik.core.api.utils import PassiveSerializer -from authentik.events.api.tasks import SystemTaskSerializer +from authentik.lib.sync.api import SyncStatusSerializer +from authentik.rbac.filters import ObjectFilter from authentik.sources.kerberos.models import KerberosSource -from authentik.sources.kerberos.tasks import CACHE_KEY_STATUS +from authentik.sources.kerberos.tasks import CACHE_KEY_STATUS, kerberos_sync +from authentik.tasks.models import Task, TaskStatus class KerberosSourceSerializer(SourceSerializer): @@ -52,13 +52,6 @@ class KerberosSourceSerializer(SourceSerializer): } -class KerberosSyncStatusSerializer(PassiveSerializer): - """Kerberos Source sync status""" - - is_running = BooleanField(read_only=True) - tasks = SystemTaskSerializer(many=True, read_only=True) - - class 
KerberosSourceViewSet(UsedByMixin, ModelViewSet): """Kerberos Source Viewset""" @@ -88,30 +81,48 @@ class KerberosSourceViewSet(UsedByMixin, ModelViewSet): ] ordering = ["name"] - @extend_schema( - responses={ - 200: KerberosSyncStatusSerializer(), - } - ) + @extend_schema(responses={200: SyncStatusSerializer()}) @action( methods=["GET"], detail=True, pagination_class=None, url_path="sync/status", - filter_backends=[], + filter_backends=[ObjectFilter], ) def sync_status(self, request: Request, slug: str) -> Response: - """Get source's sync status""" + """Get source's sync status""" source: KerberosSource = self.get_object() - tasks = list( - get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( - name="kerberos_sync", - uid__startswith=source.slug, - ) - ) + + status = {} + with source.sync_lock as lock_acquired: - status = { - "tasks": tasks, - "is_running": not lock_acquired, - } - return Response(KerberosSyncStatusSerializer(status).data) + # If we could not acquire the lock, it means a task is using it, and thus is running + status["is_running"] = not lock_acquired + + sync_schedule = None + for schedule in source.schedules.all(): + if schedule.actor_name == kerberos_sync.actor_name: + sync_schedule = schedule + + if not sync_schedule: + return Response(SyncStatusSerializer(status).data) + + last_task: Task = ( + sync_schedule.tasks.exclude( + aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) + ) + .order_by("-mtime") + .first() + ) + last_successful_task: Task = ( + sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) + .order_by("-mtime") + .first() + ) + + if last_task: + status["last_sync_status"] = last_task.aggregated_status + if last_successful_task: + status["last_successful_sync"] = last_successful_task.mtime + + return Response(SyncStatusSerializer(status).data) diff --git a/authentik/sources/kerberos/models.py b/authentik/sources/kerberos/models.py index 50fc0c1ccb..d05d71b3d4 
100644 --- a/authentik/sources/kerberos/models.py +++ b/authentik/sources/kerberos/models.py @@ -28,6 +28,9 @@ from authentik.core.models import ( ) from authentik.core.types import UILoginButton, UserSettingSerializer from authentik.flows.challenge import RedirectChallenge +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel LOGGER = get_logger() @@ -43,7 +46,7 @@ class KAdminType(models.TextChoices): OTHER = "other" -class KerberosSource(Source): +class KerberosSource(ScheduledModel, Source): """Federate Kerberos realm with authentik""" realm = models.TextField(help_text=_("Kerberos realm"), unique=True) @@ -135,6 +138,27 @@ class KerberosSource(Source): return static("authentik/sources/kerberos.png") return icon + @property + def schedule_specs(self) -> list[ScheduleSpec]: + from authentik.sources.kerberos.tasks import kerberos_connectivity_check, kerberos_sync + + return [ + ScheduleSpec( + actor=kerberos_sync, + uid=self.slug, + args=(self.pk,), + crontab=f"{fqdn_rand('kerberos_sync/' + str(self.pk))} */2 * * *", + send_on_save=True, + ), + ScheduleSpec( + actor=kerberos_connectivity_check, + uid=self.slug, + args=(self.pk,), + crontab=f"{fqdn_rand('kerberos_connectivity_check/' + str(self.pk))} * * * *", + send_on_save=True, + ), + ] + def ui_login_button(self, request: HttpRequest) -> UILoginButton: return UILoginButton( challenge=RedirectChallenge( diff --git a/authentik/sources/kerberos/settings.py b/authentik/sources/kerberos/settings.py deleted file mode 100644 index 2eac46c175..0000000000 --- a/authentik/sources/kerberos/settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""LDAP Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "sources_kerberos_sync": { - "task": "authentik.sources.kerberos.tasks.kerberos_sync_all", - "schedule": 
crontab(minute=fqdn_rand("sources_kerberos_sync"), hour="*/2"), - "options": {"queue": "authentik_scheduled"}, - }, - "sources_kerberos_connectivity_check": { - "task": "authentik.sources.kerberos.tasks.kerberos_connectivity_check", - "schedule": crontab(minute=fqdn_rand("sources_kerberos_connectivity_check"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/sources/kerberos/signals.py b/authentik/sources/kerberos/signals.py index a60eb09db3..3b1bc99ec6 100644 --- a/authentik/sources/kerberos/signals.py +++ b/authentik/sources/kerberos/signals.py @@ -1,6 +1,5 @@ """authentik kerberos source signals""" -from django.db.models.signals import post_save from django.dispatch import receiver from kadmin.exceptions import PyKAdminException from rest_framework.serializers import ValidationError @@ -10,24 +9,13 @@ from authentik.core.models import User from authentik.core.signals import password_changed from authentik.events.models import Event, EventAction from authentik.sources.kerberos.models import ( - KerberosSource, Krb5ConfContext, UserKerberosSourceConnection, ) -from authentik.sources.kerberos.tasks import kerberos_connectivity_check, kerberos_sync_single LOGGER = get_logger() -@receiver(post_save, sender=KerberosSource) -def sync_kerberos_source_on_save(sender, instance: KerberosSource, **_): - """Ensure that source is synced on save (if enabled)""" - if not instance.enabled or not instance.sync_users: - return - kerberos_sync_single.delay(instance.pk) - kerberos_connectivity_check.delay(instance.pk) - - @receiver(password_changed) def kerberos_sync_password(sender, user: User, password: str, **_): """Connect to kerberos and update password.""" @@ -55,8 +43,7 @@ def kerberos_sync_password(sender, user: User, password: str, **_): Event.new( EventAction.CONFIGURATION_ERROR, message=( - "Failed to change password in Kerberos source due to remote error: " - f"{exc}" + f"Failed to change password in Kerberos source due to remote 
error: {exc}" ), source=source, ).set_user(user).save() diff --git a/authentik/sources/kerberos/sync.py b/authentik/sources/kerberos/sync.py index 492c706a43..b6a88172a4 100644 --- a/authentik/sources/kerberos/sync.py +++ b/authentik/sources/kerberos/sync.py @@ -23,12 +23,14 @@ from authentik.sources.kerberos.models import ( Krb5ConfContext, UserKerberosSourceConnection, ) +from authentik.tasks.models import Task class KerberosSync: """Sync Kerberos users into authentik""" _source: KerberosSource + _task: Task _logger: BoundLogger _connection: KAdmin mapper: SourceMapper @@ -36,11 +38,11 @@ class KerberosSync: group_manager: PropertyMappingManager matcher: SourceMatcher - def __init__(self, source: KerberosSource): + def __init__(self, source: KerberosSource, task: Task): self._source = source + self._task = task with Krb5ConfContext(self._source): self._connection = self._source.connection() - self._messages = [] self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__) self.mapper = SourceMapper(self._source) self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"]) @@ -56,17 +58,6 @@ class KerberosSync: """UI name for the type of object this class synchronizes""" return "users" - @property - def messages(self) -> list[str]: - """Get all UI messages""" - return self._messages - - def message(self, *args, **kwargs): - """Add message that is later added to the System Task and shown to the user""" - formatted_message = " ".join(args) - self._messages.append(formatted_message) - self._logger.warning(*args, **kwargs) - def _handle_principal(self, principal: str) -> bool: try: # TODO: handle permission error @@ -163,7 +154,7 @@ class KerberosSync: def sync(self) -> int: """Iterate over all Kerberos users and create authentik_core.User instances""" if not self._source.enabled or not self._source.sync_users: - self.message("Source is disabled or user syncing is disabled for this Source") + self._task.info("Source is 
disabled or user syncing is disabled for this Source") return -1 user_count = 0 diff --git a/authentik/sources/kerberos/tasks.py b/authentik/sources/kerberos/tasks.py index 88bf4ec3e9..d56ad22bd3 100644 --- a/authentik/sources/kerberos/tasks.py +++ b/authentik/sources/kerberos/tasks.py @@ -1,67 +1,53 @@ """Kerberos Sync tasks""" from django.core.cache import cache +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from structlog.stdlib import get_logger -from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.lib.config import CONFIG from authentik.lib.sync.outgoing.exceptions import StopSync -from authentik.root.celery import CELERY_APP from authentik.sources.kerberos.models import KerberosSource from authentik.sources.kerberos.sync import KerberosSync +from authentik.tasks.models import Task LOGGER = get_logger() CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/" -@CELERY_APP.task() -def kerberos_sync_all(): - """Sync all sources""" - for source in KerberosSource.objects.filter(enabled=True, sync_users=True): - kerberos_sync_single.delay(str(source.pk)) - - -@CELERY_APP.task() -def kerberos_connectivity_check(pk: str | None = None): +@actor(description=_("Check connectivity for Kerberos sources.")) +def kerberos_connectivity_check(pk: str): """Check connectivity for Kerberos Sources""" # 2 hour timeout, this task should run every hour timeout = 60 * 60 * 2 - sources = KerberosSource.objects.filter(enabled=True) - if pk: - sources = sources.filter(pk=pk) - for source in sources: - status = source.check_connection() - cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout) + source = KerberosSource.objects.filter(enabled=True, pk=pk).first() + if not source: + return + status = source.check_connection() + 
cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout) -@CELERY_APP.task( - bind=True, - base=SystemTask, - # We take the configured hours timeout time by 2.5 as we run user and - # group in parallel and then membership, so 2x is to cover the serial tasks, - # and 0.5x on top of that to give some more leeway - soft_time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5, - task_time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5, +@actor( + time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000, + description=_("Sync Kerberos source."), ) -def kerberos_sync_single(self, source_pk: str): - """Sync a single source""" - source: KerberosSource = KerberosSource.objects.filter(pk=source_pk).first() - if not source or not source.enabled: +def kerberos_sync(pk: str): + self: Task = CurrentTask.get_task() + source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first() + if not source: return try: with source.sync_lock as lock_acquired: if not lock_acquired: + self.info("Synchronization is already running. 
Skipping") LOGGER.debug( "Failed to acquire lock for Kerberos sync, skipping task", source=source.slug ) return - # Delete all sync tasks from the cache - DBSystemTask.objects.filter(name="kerberos_sync", uid__startswith=source.slug).delete() - syncer = KerberosSync(source) + syncer = KerberosSync(source, self) syncer.sync() - self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages) except StopSync as exc: LOGGER.warning("Error syncing kerberos", exc=exc, source=source) - self.set_error(exc) + self.error(exc) + raise exc diff --git a/authentik/sources/kerberos/tests/test_sync.py b/authentik/sources/kerberos/tests/test_sync.py index 546478acf2..ea571d01df 100644 --- a/authentik/sources/kerberos/tests/test_sync.py +++ b/authentik/sources/kerberos/tests/test_sync.py @@ -5,8 +5,9 @@ from authentik.core.models import User from authentik.lib.generators import generate_id from authentik.sources.kerberos.models import KerberosSource, KerberosSourcePropertyMapping from authentik.sources.kerberos.sync import KerberosSync -from authentik.sources.kerberos.tasks import kerberos_sync_all +from authentik.sources.kerberos.tasks import kerberos_sync from authentik.sources.kerberos.tests.utils import KerberosTestCase +from authentik.tasks.models import Task class TestKerberosSync(KerberosTestCase): @@ -31,7 +32,7 @@ class TestKerberosSync(KerberosTestCase): def test_default_mappings(self): """Test default mappings""" - KerberosSync(self.source).sync() + KerberosSync(self.source, Task()).sync() self.assertTrue( User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists() @@ -54,7 +55,7 @@ class TestKerberosSync(KerberosTestCase): ) self.source.user_property_mappings.set([noop, email, dont_sync_service]) - KerberosSync(self.source).sync() + KerberosSync(self.source, Task()).sync() self.assertTrue( User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists() @@ -69,7 +70,7 @@ class TestKerberosSync(KerberosTestCase): def test_tasks(self): """Test 
Scheduled tasks""" - kerberos_sync_all.delay().get() + kerberos_sync.send(self.source.pk) self.assertTrue( User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists() ) diff --git a/authentik/sources/ldap/api.py b/authentik/sources/ldap/api.py index b453b80552..77be41a3cb 100644 --- a/authentik/sources/ldap/api.py +++ b/authentik/sources/ldap/api.py @@ -5,7 +5,6 @@ from typing import Any from django.core.cache import cache from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema, inline_serializer -from guardian.shortcuts import get_objects_for_user from rest_framework.decorators import action from rest_framework.exceptions import ValidationError from rest_framework.fields import DictField, ListField, SerializerMethodField @@ -24,14 +23,16 @@ from authentik.core.api.sources import ( ) from authentik.core.api.used_by import UsedByMixin from authentik.crypto.models import CertificateKeyPair -from authentik.lib.sync.outgoing.api import SyncStatusSerializer +from authentik.lib.sync.api import SyncStatusSerializer +from authentik.rbac.filters import ObjectFilter from authentik.sources.ldap.models import ( GroupLDAPSourceConnection, LDAPSource, LDAPSourcePropertyMapping, UserLDAPSourceConnection, ) -from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES +from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES, ldap_sync +from authentik.tasks.models import Task, TaskStatus class LDAPSourceSerializer(SourceSerializer): @@ -155,33 +156,50 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): search_fields = ["name", "slug"] ordering = ["name"] - @extend_schema( - responses={ - 200: SyncStatusSerializer(), - } - ) + @extend_schema(responses={200: SyncStatusSerializer()}) @action( methods=["GET"], detail=True, pagination_class=None, url_path="sync/status", - filter_backends=[], + filter_backends=[ObjectFilter], ) def sync_status(self, request: Request, slug: str) -> Response: - 
"""Get source's sync status""" + """Get provider's sync status""" source: LDAPSource = self.get_object() - tasks = list( - get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( - name="ldap_sync", - uid__startswith=source.slug, - ) - ) + + status = {} + with source.sync_lock as lock_acquired: - status = { - "tasks": tasks, - # If we could not acquire the lock, it means a task is using it, and thus is running - "is_running": not lock_acquired, - } + # If we could not acquire the lock, it means a task is using it, and thus is running + status["is_running"] = not lock_acquired + + sync_schedule = None + for schedule in source.schedules.all(): + if schedule.actor_name == ldap_sync.actor_name: + sync_schedule = schedule + + if not sync_schedule: + return Response(SyncStatusSerializer(status).data) + + last_task: Task = ( + sync_schedule.tasks.exclude( + aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) + ) + .order_by("-mtime") + .first() + ) + last_successful_task: Task = ( + sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) + .order_by("-mtime") + .first() + ) + + if last_task: + status["last_sync_status"] = last_task.aggregated_status + if last_successful_task: + status["last_successful_sync"] = last_successful_task.mtime + return Response(SyncStatusSerializer(status).data) @extend_schema( diff --git a/authentik/sources/ldap/management/commands/ldap_sync.py b/authentik/sources/ldap/management/commands/ldap_sync.py index 22404ad49e..ed3952818a 100644 --- a/authentik/sources/ldap/management/commands/ldap_sync.py +++ b/authentik/sources/ldap/management/commands/ldap_sync.py @@ -3,10 +3,6 @@ from structlog.stdlib import get_logger from authentik.sources.ldap.models import LDAPSource -from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer -from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer -from authentik.sources.ldap.sync.users import UserLDAPSynchronizer 
-from authentik.sources.ldap.tasks import ldap_sync_paginator from authentik.tenants.management import TenantCommand LOGGER = get_logger() @@ -24,10 +20,5 @@ class Command(TenantCommand): if not source: LOGGER.warning("Source does not exist", slug=source_slug) continue - tasks = ( - ldap_sync_paginator(source, UserLDAPSynchronizer) - + ldap_sync_paginator(source, GroupLDAPSynchronizer) - + ldap_sync_paginator(source, MembershipLDAPSynchronizer) - ) - for task in tasks: - task() + for schedule in source.schedules.all(): + schedule.send().get_result() diff --git a/authentik/sources/ldap/models.py b/authentik/sources/ldap/models.py index 975c019cbd..37787a3163 100644 --- a/authentik/sources/ldap/models.py +++ b/authentik/sources/ldap/models.py @@ -25,6 +25,9 @@ from authentik.core.models import ( from authentik.crypto.models import CertificateKeyPair from authentik.lib.config import CONFIG from authentik.lib.models import DomainlessURLValidator +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel LDAP_TIMEOUT = 15 LDAP_UNIQUENESS = "ldap_uniq" @@ -53,7 +56,7 @@ class MultiURLValidator(DomainlessURLValidator): super().__call__(value) -class LDAPSource(Source): +class LDAPSource(ScheduledModel, Source): """Federate LDAP Directory with authentik, or create new accounts in LDAP.""" server_uri = models.TextField( @@ -159,6 +162,27 @@ class LDAPSource(Source): return LDAPSourceSerializer + @property + def schedule_specs(self) -> list[ScheduleSpec]: + from authentik.sources.ldap.tasks import ldap_connectivity_check, ldap_sync + + return [ + ScheduleSpec( + actor=ldap_sync, + uid=self.slug, + args=(self.pk,), + crontab=f"{fqdn_rand('ldap_sync/' + str(self.pk))} */2 * * *", + send_on_save=True, + ), + ScheduleSpec( + actor=ldap_connectivity_check, + uid=self.slug, + args=(self.pk,), + crontab=f"{fqdn_rand('ldap_connectivity_check/' + str(self.pk))} * * * *", + 
send_on_save=True, + ), + ] + @property def property_mapping_type(self) -> "type[PropertyMapping]": from authentik.sources.ldap.models import LDAPSourcePropertyMapping diff --git a/authentik/sources/ldap/settings.py b/authentik/sources/ldap/settings.py deleted file mode 100644 index c82dbeb0cb..0000000000 --- a/authentik/sources/ldap/settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""LDAP Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "sources_ldap_sync": { - "task": "authentik.sources.ldap.tasks.ldap_sync_all", - "schedule": crontab(minute=fqdn_rand("sources_ldap_sync"), hour="*/2"), - "options": {"queue": "authentik_scheduled"}, - }, - "sources_ldap_connectivity_check": { - "task": "authentik.sources.ldap.tasks.ldap_connectivity_check", - "schedule": crontab(minute=fqdn_rand("sources_ldap_connectivity_check"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/sources/ldap/signals.py b/authentik/sources/ldap/signals.py index a2bad559bd..7d7b725bd8 100644 --- a/authentik/sources/ldap/signals.py +++ b/authentik/sources/ldap/signals.py @@ -2,7 +2,6 @@ from typing import Any -from django.db.models.signals import post_save from django.dispatch import receiver from django.utils.translation import gettext_lazy as _ from ldap3.core.exceptions import LDAPOperationResult @@ -15,29 +14,11 @@ from authentik.events.models import Event, EventAction from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.password import LDAPPasswordChanger -from authentik.sources.ldap.tasks import ldap_connectivity_check, ldap_sync_single from authentik.stages.prompt.signals import password_validate LOGGER = get_logger() -@receiver(post_save, sender=LDAPSource) -def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): - """Ensure that source is synced on save (if enabled)""" - if not 
instance.enabled: - return - ldap_connectivity_check.delay(instance.pk) - # Don't sync sources when they don't have any property mappings. This will only happen if: - # - the user forgets to set them or - # - the source is newly created, this is the first save event - # and the mappings are created with an m2m event - if instance.sync_users and not instance.user_property_mappings.exists(): - return - if instance.sync_groups and not instance.group_property_mappings.exists(): - return - ldap_sync_single.delay(instance.pk) - - @receiver(password_validate) def ldap_password_validate(sender, password: str, plan_context: dict[str, Any], **__): """if there's an LDAP Source with enabled password sync, check the password""" diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 3d2498b41b..931313caed 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -10,22 +10,23 @@ from authentik.core.sources.mapper import SourceMapper from authentik.lib.config import CONFIG from authentik.lib.sync.mapper import PropertyMappingManager from authentik.sources.ldap.models import LDAPSource, flatten +from authentik.tasks.models import Task class BaseLDAPSynchronizer: """Sync LDAP Users and groups into authentik""" _source: LDAPSource + _task: Task _logger: BoundLogger _connection: Connection - _messages: list[str] mapper: SourceMapper manager: PropertyMappingManager - def __init__(self, source: LDAPSource): + def __init__(self, source: LDAPSource, task: Task): self._source = source + self._task = task self._connection = source.connection() - self._messages = [] self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) @staticmethod @@ -46,11 +47,6 @@ class BaseLDAPSynchronizer: """Sync function, implemented in subclass""" raise NotImplementedError() - @property - def messages(self) -> list[str]: - """Get all UI messages""" - return self._messages - @property def base_dn_users(self) -> str: 
"""Shortcut to get full base_dn for user lookups""" @@ -65,14 +61,6 @@ class BaseLDAPSynchronizer: return f"{self._source.additional_group_dn},{self._source.base_dn}" return self._source.base_dn - def message(self, *args, **kwargs): - """Add message that is later added to the System Task and shown to the user""" - formatted_message = " ".join(args) - if "dn" in kwargs: - formatted_message += f"; DN: {kwargs['dn']}" - self._messages.append(formatted_message) - self._logger.warning(*args, **kwargs) - def get_objects(self, **kwargs) -> Generator: """Get objects from LDAP, implemented in subclass""" raise NotImplementedError() diff --git a/authentik/sources/ldap/sync/forward_delete_groups.py b/authentik/sources/ldap/sync/forward_delete_groups.py index 875601162d..515ae2171e 100644 --- a/authentik/sources/ldap/sync/forward_delete_groups.py +++ b/authentik/sources/ldap/sync/forward_delete_groups.py @@ -19,7 +19,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer): def get_objects(self, **kwargs) -> Generator: if not self._source.sync_groups or not self._source.delete_not_found_objects: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return iter(()) uuid = uuid4() @@ -54,7 +54,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer): def sync(self, group_pks: tuple) -> int: """Delete authentik groups""" if not self._source.sync_groups or not self._source.delete_not_found_objects: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return -1 self._logger.debug("Deleting groups", group_pks=group_pks) _, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete() diff --git a/authentik/sources/ldap/sync/forward_delete_users.py b/authentik/sources/ldap/sync/forward_delete_users.py index 2ea81cc735..6e04f5b15a 100644 --- a/authentik/sources/ldap/sync/forward_delete_users.py +++ 
b/authentik/sources/ldap/sync/forward_delete_users.py @@ -21,7 +21,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer): def get_objects(self, **kwargs) -> Generator: if not self._source.sync_users or not self._source.delete_not_found_objects: - self.message("User syncing is disabled for this Source") + self._task.info("User syncing is disabled for this Source") return iter(()) uuid = uuid4() @@ -56,7 +56,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer): def sync(self, user_pks: tuple) -> int: """Delete authentik users""" if not self._source.sync_users or not self._source.delete_not_found_objects: - self.message("User syncing is disabled for this Source") + self._task.info("User syncing is disabled for this Source") return -1 self._logger.debug("Deleting users", user_pks=user_pks) _, deleted_per_type = User.objects.filter(pk__in=user_pks).delete() diff --git a/authentik/sources/ldap/sync/groups.py b/authentik/sources/ldap/sync/groups.py index 3119b7905d..dce6ffebec 100644 --- a/authentik/sources/ldap/sync/groups.py +++ b/authentik/sources/ldap/sync/groups.py @@ -21,13 +21,15 @@ from authentik.sources.ldap.models import ( flatten, ) from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer +from authentik.tasks.models import Task class GroupLDAPSynchronizer(BaseLDAPSynchronizer): """Sync LDAP Users and groups into authentik""" - def __init__(self, source: LDAPSource): - super().__init__(source) + def __init__(self, source: LDAPSource, task: Task): + super().__init__(source, task) + self._source = source self.mapper = SourceMapper(source) self.manager = self.mapper.get_manager(Group, ["ldap", "dn"]) @@ -37,7 +39,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): def get_objects(self, **kwargs) -> Generator: if not self._source.sync_groups: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return iter(()) return self.search_paginator( search_base=self.base_dn_groups, @@ 
-54,7 +56,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): def sync(self, page_data: list) -> int: """Iterate over all LDAP Groups and create authentik_core.Group instances""" if not self._source.sync_groups: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return -1 group_count = 0 for group in page_data: @@ -62,9 +64,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): continue group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) if not (uniq := self.get_identifier(attributes)): - self.message( + self._task.info( f"Uniqueness field not found/not set in attributes: '{group_dn}'", - attributes=attributes.keys(), + attributes=list(attributes.keys()), dn=group_dn, ) continue diff --git a/authentik/sources/ldap/sync/membership.py b/authentik/sources/ldap/sync/membership.py index 277cd90ea9..7852deff50 100644 --- a/authentik/sources/ldap/sync/membership.py +++ b/authentik/sources/ldap/sync/membership.py @@ -9,6 +9,7 @@ from ldap3 import SUBTREE from authentik.core.models import Group, User from authentik.sources.ldap.models import LDAP_DISTINGUISHED_NAME, LDAP_UNIQUENESS, LDAPSource from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer +from authentik.tasks.models import Task class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): @@ -16,8 +17,8 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): group_cache: dict[str, Group] - def __init__(self, source: LDAPSource): - super().__init__(source) + def __init__(self, source: LDAPSource, task: Task): + super().__init__(source, task) self.group_cache: dict[str, Group] = {} @staticmethod @@ -26,7 +27,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): def get_objects(self, **kwargs) -> Generator: if not self._source.sync_groups: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return iter(()) # If we are looking up groups from 
users, we don't need to fetch the group membership field @@ -45,7 +46,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): def sync(self, page_data: list) -> int: """Iterate over all Users and assign Groups using memberOf Field""" if not self._source.sync_groups: - self.message("Group syncing is disabled for this Source") + self._task.info("Group syncing is disabled for this Source") return -1 membership_count = 0 for group in page_data: @@ -94,7 +95,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): # group_uniq might be a single string or an array with (hopefully) a single string if isinstance(group_uniq, list): if len(group_uniq) < 1: - self.message( + self._task.info( f"Group does not have a uniqueness attribute: '{group_dn}'", group=group_dn, ) @@ -104,7 +105,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq}) if not groups.exists(): if self._source.sync_groups: - self.message( + self._task.info( f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'", group=group_dn, ) diff --git a/authentik/sources/ldap/sync/users.py b/authentik/sources/ldap/sync/users.py index f936b04b0b..a23e456f1a 100644 --- a/authentik/sources/ldap/sync/users.py +++ b/authentik/sources/ldap/sync/users.py @@ -23,13 +23,14 @@ from authentik.sources.ldap.models import ( from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer from authentik.sources.ldap.sync.vendor.freeipa import FreeIPA from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory +from authentik.tasks.models import Task class UserLDAPSynchronizer(BaseLDAPSynchronizer): """Sync LDAP Users into authentik""" - def __init__(self, source: LDAPSource): - super().__init__(source) + def __init__(self, source: LDAPSource, task: Task): + super().__init__(source, task) self.mapper = SourceMapper(source) self.manager = self.mapper.get_manager(User, ["ldap", "dn"]) @@ -39,7 +40,7 @@ 
class UserLDAPSynchronizer(BaseLDAPSynchronizer): def get_objects(self, **kwargs) -> Generator: if not self._source.sync_users: - self.message("User syncing is disabled for this Source") + self._task.info("User syncing is disabled for this Source") return iter(()) return self.search_paginator( search_base=self.base_dn_users, @@ -56,7 +57,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): def sync(self, page_data: list) -> int: """Iterate over all LDAP Users and create authentik_core.User instances""" if not self._source.sync_users: - self.message("User syncing is disabled for this Source") + self._task.info("User syncing is disabled for this Source") return -1 user_count = 0 for user in page_data: @@ -64,9 +65,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): continue user_dn = flatten(user.get("entryDN", user.get("dn"))) if not (uniq := self.get_identifier(attributes)): - self.message( + self._task.info( f"Uniqueness field not found/not set in attributes: '{user_dn}'", - attributes=attributes.keys(), + attributes=list(attributes.keys()), dn=user_dn, ) continue @@ -112,6 +113,8 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): else: self._logger.debug("Synced User", user=ak_user.username, created=created) user_count += 1 - MicrosoftActiveDirectory(self._source).sync(attributes, ak_user, created) - FreeIPA(self._source).sync(attributes, ak_user, created) + MicrosoftActiveDirectory(self._source, self._task).sync( + attributes, ak_user, created + ) + FreeIPA(self._source, self._task).sync(attributes, ak_user, created) return user_count diff --git a/authentik/sources/ldap/sync/vendor/freeipa.py b/authentik/sources/ldap/sync/vendor/freeipa.py index 44e127e05a..3a356011ed 100644 --- a/authentik/sources/ldap/sync/vendor/freeipa.py +++ b/authentik/sources/ldap/sync/vendor/freeipa.py @@ -30,7 +30,7 @@ class FreeIPA(BaseLDAPSynchronizer): pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now()) pwd_last_set = pwd_last_set.replace(tzinfo=UTC) if 
created or pwd_last_set >= user.password_change_date: - self.message(f"'{user.username}': Reset user's password") + self._task.info(f"'{user.username}': Reset user's password") self._logger.debug( "Reset user's password", user=user.username, diff --git a/authentik/sources/ldap/sync/vendor/ms_ad.py b/authentik/sources/ldap/sync/vendor/ms_ad.py index fd02308973..1400c66006 100644 --- a/authentik/sources/ldap/sync/vendor/ms_ad.py +++ b/authentik/sources/ldap/sync/vendor/ms_ad.py @@ -60,7 +60,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer): pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now()) pwd_last_set = pwd_last_set.replace(tzinfo=UTC) if created or pwd_last_set >= user.password_change_date: - self.message(f"'{user.username}': Reset user's password") + self._task.info(f"'{user.username}': Reset user's password") self._logger.debug( "Reset user's password", user=user.username, diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index f800bd5abe..d0c9f35432 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -2,18 +2,18 @@ from uuid import uuid4 -from celery import chain, group from django.core.cache import cache +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor +from dramatiq.composition import group +from dramatiq.message import Message from ldap3.core.exceptions import LDAPException from structlog.stdlib import get_logger -from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.lib.config import CONFIG from authentik.lib.sync.outgoing.exceptions import StopSync from authentik.lib.utils.reflection import class_to_path, path_to_class -from authentik.root.celery import CELERY_APP from authentik.sources.ldap.models import LDAPSource from 
authentik.sources.ldap.sync.base import BaseLDAPSynchronizer from authentik.sources.ldap.sync.forward_delete_groups import GroupLDAPForwardDeletion @@ -21,6 +21,7 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer +from authentik.tasks.models import Task LOGGER = get_logger() SYNC_CLASSES = [ @@ -32,92 +33,101 @@ CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/" CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/" -@CELERY_APP.task() -def ldap_sync_all(): - """Sync all sources""" - for source in LDAPSource.objects.filter(enabled=True): - ldap_sync_single.apply_async(args=[str(source.pk)]) - - -@CELERY_APP.task() +@actor(description=_("Check connectivity for LDAP source.")) def ldap_connectivity_check(pk: str | None = None): """Check connectivity for LDAP Sources""" - # 2 hour timeout, this task should run every hour timeout = 60 * 60 * 2 - sources = LDAPSource.objects.filter(enabled=True) - if pk: - sources = sources.filter(pk=pk) - for source in sources: - status = source.check_connection() - cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout) + source = LDAPSource.objects.filter(pk=pk, enabled=True).first() + if not source: + return + status = source.check_connection() + cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout) -@CELERY_APP.task( +@actor( # We take the configured hours timeout time by 3.5 as we run user and # group in parallel and then membership, then deletions, so 3x is to cover the serial tasks, # and 0.5x on top of that to give some more leeway - soft_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 3.5, - task_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 3.5, + time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 
1000) * 3.5, + description=_("Sync LDAP source."), ) -def ldap_sync_single(source_pk: str): +def ldap_sync(source_pk: str): """Sync a single source""" - source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() + task: Task = CurrentTask.get_task() + source: LDAPSource = LDAPSource.objects.filter(pk=source_pk, enabled=True).first() if not source: return + task.set_uid(f"{source.slug}") with source.sync_lock as lock_acquired: if not lock_acquired: + task.info("Synchronization is already running. Skipping") LOGGER.debug("Failed to acquire lock for LDAP sync, skipping task", source=source.slug) return - # Delete all sync tasks from the cache - DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() - # The order of these operations needs to be preserved as each depends on the previous one(s) - # 1. User and group sync can happen simultaneously - # 2. Membership sync needs to run afterwards - # 3. Finally, user and group deletions can happen simultaneously - user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( - source, GroupLDAPSynchronizer + user_group_tasks = group( + ldap_sync_paginator(task, source, UserLDAPSynchronizer) + + ldap_sync_paginator(task, source, GroupLDAPSynchronizer) ) - membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) - user_group_deletion = ldap_sync_paginator( - source, UserLDAPForwardDeletion - ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion) - # Celery is buggy with empty groups, so we are careful only to add non-empty groups. 
- # See https://github.com/celery/celery/issues/9772 - task_groups = [] - if user_group_sync: - task_groups.append(group(user_group_sync)) - if membership_sync: - task_groups.append(group(membership_sync)) - if user_group_deletion: - task_groups.append(group(user_group_deletion)) + membership_tasks = group(ldap_sync_paginator(task, source, MembershipLDAPSynchronizer)) - all_tasks = chain(task_groups) - all_tasks() + deletion_tasks = group( + ldap_sync_paginator(task, source, UserLDAPForwardDeletion) + + ldap_sync_paginator(task, source, GroupLDAPForwardDeletion), + ) + + # User and group sync can happen at once, they have no dependencies on each other + user_group_tasks.run().wait( + timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000 + ) + # Membership sync needs to run afterwards + membership_tasks.run().wait( + timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000 + ) + # Finally, deletions. What we'd really like to do here is something like + # ``` + # user_identifiers = + # User.objects.exclude( + # usersourceconnection__identifier__in=user_uniqueness_identifiers, + # ).delete() + # ``` + # This runs into performance issues in large installations. So instead we spread the + # work out into three steps: + # 1. Get every object from the LDAP source. + # 2. Mark every object as "safe" in the database. This is quick, but any error could + # mean deleting users which should not be deleted, so we do it immediately, in + # large chunks, and only queue the deletion step afterwards. + # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in + # small chunks. 
+ deletion_tasks.run().wait( + timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000, + ) -def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: +def ldap_sync_paginator( + task: Task, source: LDAPSource, sync: type[BaseLDAPSynchronizer] +) -> list[Message]: """Return a list of task signatures with LDAP pagination data""" - sync_inst: BaseLDAPSynchronizer = sync(source) - signatures = [] + sync_inst: BaseLDAPSynchronizer = sync(source, task) + messages = [] for page in sync_inst.get_objects(): page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) - page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key) - signatures.append(page_sync) - return signatures + page_sync = ldap_sync_page.message_with_options( + args=(source.pk, class_to_path(sync), page_cache_key), + rel_obj=task.rel_obj, + ) + messages.append(page_sync) + return messages -@CELERY_APP.task( - bind=True, - base=SystemTask, - soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), - task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), +@actor( + time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000, + description=_("Sync page for LDAP source."), ) -def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: str): +def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" - self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") + self: Task = CurrentTask.get_task() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() if not source: # Because the source couldn't be found, we don't have a UID @@ -127,7 +137,7 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: uid = page_cache_key.replace(CACHE_KEY_PREFIX, "") self.set_uid(f"{source.slug}:{sync.name()}:{uid}") try: - sync_inst: 
BaseLDAPSynchronizer = sync(source) + sync_inst: BaseLDAPSynchronizer = sync(source, self) page = cache.get(page_cache_key) if not page: error_message = ( @@ -135,18 +145,14 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: + "Try increasing ldap.task_timeout_hours" ) LOGGER.warning(error_message) - self.set_status(TaskStatus.ERROR, error_message) + self.error(error_message) return cache.touch(page_cache_key) count = sync_inst.sync(page) - messages = sync_inst.messages - messages.append(f"Synced {count} objects.") - self.set_status( - TaskStatus.SUCCESSFUL, - *messages, - ) + self.info(f"Synced {count} objects.") cache.delete(page_cache_key) except (LDAPException, StopSync) as exc: - # No explicit event is created here as .set_status with an error will do that + # No explicit event is created here as .error will do that LOGGER.warning("Failed to sync LDAP", exc=exc, source=source) - self.set_error(exc) + self.error(exc) + raise exc diff --git a/authentik/sources/ldap/tests/test_auth.py b/authentik/sources/ldap/tests/test_auth.py index 6dd06f8917..def896ab4d 100644 --- a/authentik/sources/ldap/tests/test_auth.py +++ b/authentik/sources/ldap/tests/test_auth.py @@ -13,6 +13,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection +from authentik.tasks.models import Task LDAP_PASSWORD = generate_key() @@ -43,7 +44,7 @@ class LDAPSyncTests(TestCase): raw_conn.bind = bind_mock connection = MagicMock(return_value=raw_conn) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() user = User.objects.get(username="user0_sn") @@ -71,7 +72,7 @@ class 
LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() user = User.objects.get(username="user0_sn") @@ -98,7 +99,7 @@ class LDAPSyncTests(TestCase): self.source.save() connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() user = User.objects.get(username="user0_sn") diff --git a/authentik/sources/ldap/tests/test_sync.py b/authentik/sources/ldap/tests/test_sync.py index a6b3659360..b313efc669 100644 --- a/authentik/sources/ldap/tests/test_sync.py +++ b/authentik/sources/ldap/tests/test_sync.py @@ -8,8 +8,7 @@ from django.test import TestCase from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Group, User from authentik.core.tests.utils import create_test_admin_user -from authentik.events.models import Event, EventAction, SystemTask -from authentik.events.system_tasks import TaskStatus +from authentik.events.models import Event, EventAction from authentik.lib.generators import generate_id, generate_key from authentik.lib.sync.outgoing.exceptions import StopSync from authentik.lib.utils.reflection import class_to_path @@ -23,7 +22,7 @@ from authentik.sources.ldap.sync.forward_delete_users import DELETE_CHUNK_SIZE from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer -from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all +from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page from 
authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection from authentik.sources.ldap.tests.mock_slapd import ( @@ -33,6 +32,7 @@ from authentik.sources.ldap.tests.mock_slapd import ( user_in_slapd_cn, user_in_slapd_uid, ) +from authentik.tasks.models import Task LDAP_PASSWORD = generate_key() @@ -54,9 +54,7 @@ class LDAPSyncTests(TestCase): """Test sync with missing page""" connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get() - task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first() - self.assertEqual(task.status, TaskStatus.ERROR) + ldap_sync_page.send(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo") def test_sync_error(self): """Test user sync""" @@ -74,7 +72,7 @@ class LDAPSyncTests(TestCase): self.source.save() connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) with self.assertRaises(StopSync): user_sync.sync_full() self.assertFalse(User.objects.filter(username="user0_sn").exists()) @@ -105,7 +103,7 @@ class LDAPSyncTests(TestCase): # we basically just test that the mappings don't throw errors with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() def test_sync_users_ad(self): @@ -133,7 +131,7 @@ class LDAPSyncTests(TestCase): ) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) 
user_sync.sync_full() user = User.objects.filter(username="user0_sn").first() self.assertEqual(user.attributes["foo"], "bar") @@ -152,7 +150,7 @@ class LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists()) @@ -168,7 +166,7 @@ class LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists()) @@ -193,11 +191,11 @@ class LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() - group_sync = GroupLDAPSynchronizer(self.source) + group_sync = GroupLDAPSynchronizer(self.source, Task()) group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) + membership_sync = MembershipLDAPSynchronizer(self.source, Task()) membership_sync.sync_full() self.assertTrue( @@ -230,9 +228,9 @@ class LDAPSyncTests(TestCase): parent_group = Group.objects.get(name=_user.username) self.source.sync_parent_group = parent_group self.source.save() - group_sync = GroupLDAPSynchronizer(self.source) + group_sync = GroupLDAPSynchronizer(self.source, Task()) 
group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) + membership_sync = MembershipLDAPSynchronizer(self.source, Task()) membership_sync.sync_full() group: Group = Group.objects.filter(name="test-group").first() self.assertIsNotNone(group) @@ -256,9 +254,9 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): self.source.save() - group_sync = GroupLDAPSynchronizer(self.source) + group_sync = GroupLDAPSynchronizer(self.source, Task()) group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) + membership_sync = MembershipLDAPSynchronizer(self.source, Task()) membership_sync.sync_full() group = Group.objects.filter(name="group1") self.assertTrue(group.exists()) @@ -290,11 +288,11 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): self.source.save() - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() - group_sync = GroupLDAPSynchronizer(self.source) + group_sync = GroupLDAPSynchronizer(self.source, Task()) group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) + membership_sync = MembershipLDAPSynchronizer(self.source, Task()) membership_sync.sync_full() # Test if membership mapping based on memberUid works. 
posix_group = Group.objects.filter(name="group-posix").first() @@ -327,11 +325,11 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): self.source.save() - user_sync = UserLDAPSynchronizer(self.source) + user_sync = UserLDAPSynchronizer(self.source, Task()) user_sync.sync_full() - group_sync = GroupLDAPSynchronizer(self.source) + group_sync = GroupLDAPSynchronizer(self.source, Task()) group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) + membership_sync = MembershipLDAPSynchronizer(self.source, Task()) membership_sync.sync_full() # Test if membership mapping based on memberUid works. posix_group = Group.objects.filter(name="group-posix").first() @@ -348,7 +346,7 @@ class LDAPSyncTests(TestCase): self.source.save() connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) def test_tasks_openldap(self): """Test Scheduled tasks""" @@ -363,7 +361,7 @@ class LDAPSyncTests(TestCase): self.source.save() connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) def test_user_deletion(self): """Test user deletion""" @@ -378,7 +376,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertFalse(User.objects.filter(username="not-in-the-source").exists()) def test_user_deletion_still_in_source(self): @@ -396,7 +394,7 @@ class LDAPSyncTests(TestCase): connection = 
MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(User.objects.filter(username=username).exists()) def test_user_deletion_no_sync(self): @@ -413,7 +411,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(User.objects.filter(username="not-in-the-source").exists()) def test_user_deletion_no_delete(self): @@ -428,7 +426,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(User.objects.filter(username="not-in-the-source").exists()) def test_group_deletion(self): @@ -444,7 +442,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertFalse(Group.objects.filter(name="not-in-the-source").exists()) def test_group_deletion_still_in_source(self): @@ -462,7 +460,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(Group.objects.filter(name=groupname).exists()) def test_group_deletion_no_sync(self): @@ -479,7 +477,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with 
patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(Group.objects.filter(name="not-in-the-source").exists()) def test_group_deletion_no_delete(self): @@ -494,7 +492,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertTrue(Group.objects.filter(name="not-in-the-source").exists()) def test_batch_deletion(self): @@ -517,7 +515,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + ldap_sync.send(self.source.pk) self.assertFalse(User.objects.filter(username__startswith="not-in-the-source").exists()) self.assertFalse(Group.objects.filter(name__startswith="not-in-the-source").exists()) diff --git a/authentik/sources/oauth/apps.py b/authentik/sources/oauth/apps.py index 926736a747..919bf9ba10 100644 --- a/authentik/sources/oauth/apps.py +++ b/authentik/sources/oauth/apps.py @@ -3,6 +3,8 @@ from structlog.stdlib import get_logger from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec LOGGER = get_logger() @@ -41,3 +43,14 @@ class AuthentikSourceOAuthConfig(ManagedAppConfig): except ImportError as exc: LOGGER.warning("Failed to load OAuth Source", exc=exc) return super().import_related() + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.sources.oauth.tasks import update_well_known_jwks + + return [ + ScheduleSpec( + actor=update_well_known_jwks, + crontab=f"{fqdn_rand('update_well_known_jwks')} */3 * * *", + ), + ] diff --git a/authentik/sources/oauth/settings.py 
b/authentik/sources/oauth/settings.py deleted file mode 100644 index 6580a906f5..0000000000 --- a/authentik/sources/oauth/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""OAuth source settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "update_oauth_source_oidc_well_known": { - "task": "authentik.sources.oauth.tasks.update_well_known_jwks", - "schedule": crontab(minute=fqdn_rand("update_well_known_jwks"), hour="*/3"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/sources/oauth/tasks.py b/authentik/sources/oauth/tasks.py index 42b110a1f5..ce36ba2dde 100644 --- a/authentik/sources/oauth/tasks.py +++ b/authentik/sources/oauth/tasks.py @@ -2,23 +2,27 @@ from json import dumps +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from requests import RequestException from structlog.stdlib import get_logger -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.lib.utils.http import get_http_session -from authentik.root.celery import CELERY_APP from authentik.sources.oauth.models import OAuthSource +from authentik.tasks.models import Task LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=SystemTask) -def update_well_known_jwks(self: SystemTask): - """Update OAuth sources' config from well_known, and JWKS info from the configured URL""" +@actor( + description=_( + "Update OAuth sources' config from well_known, and JWKS info from the configured URL." 
+ ) +) +def update_well_known_jwks(): + self: Task = CurrentTask.get_task() session = get_http_session() - messages = [] for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""): try: well_known_config = session.get(source.oidc_well_known_url) @@ -26,7 +30,7 @@ def update_well_known_jwks(self: SystemTask): except RequestException as exc: text = exc.response.text if exc.response else str(exc) LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text) - messages.append(f"Failed to update OIDC configuration for {source.slug}") + self.info(f"Failed to update OIDC configuration for {source.slug}") continue config: dict = well_known_config.json() try: @@ -51,7 +55,7 @@ def update_well_known_jwks(self: SystemTask): source=source, exc=exc, ) - messages.append(f"Failed to update OIDC configuration for {source.slug}") + self.info(f"Failed to update OIDC configuration for {source.slug}") continue if dirty: LOGGER.info("Updating sources' OpenID Configuration", source=source) @@ -64,11 +68,10 @@ def update_well_known_jwks(self: SystemTask): except RequestException as exc: text = exc.response.text if exc.response else str(exc) LOGGER.warning("Failed to update JWKS", source=source, exc=exc, text=text) - messages.append(f"Failed to update JWKS for {source.slug}") + self.info(f"Failed to update JWKS for {source.slug}") continue config = jwks_config.json() if dumps(source.oidc_jwks, sort_keys=True) != dumps(config, sort_keys=True): source.oidc_jwks = config LOGGER.info("Updating sources' JWKS", source=source) source.save() - self.set_status(TaskStatus.SUCCESSFUL, *messages) diff --git a/authentik/sources/oauth/tests/test_tasks.py b/authentik/sources/oauth/tests/test_tasks.py index 9ab3aab4bf..6fa900c86e 100644 --- a/authentik/sources/oauth/tests/test_tasks.py +++ b/authentik/sources/oauth/tests/test_tasks.py @@ -35,7 +35,7 @@ class TestOAuthSourceTasks(TestCase): }, ) mock.get("http://foo/jwks", json={"foo": "bar"}) - update_well_known_jwks() + 
update_well_known_jwks.send() self.source.refresh_from_db() self.assertEqual(self.source.authorization_url, "foo") self.assertEqual(self.source.access_token_url, "foo") diff --git a/authentik/sources/plex/models.py b/authentik/sources/plex/models.py index 3dced6bc50..765139bdb1 100644 --- a/authentik/sources/plex/models.py +++ b/authentik/sources/plex/models.py @@ -19,7 +19,10 @@ from authentik.core.models import ( from authentik.core.types import UILoginButton, UserSettingSerializer from authentik.flows.challenge import Challenge, ChallengeResponse from authentik.lib.generators import generate_id +from authentik.lib.utils.time import fqdn_rand from authentik.stages.identification.stage import LoginChallengeMixin +from authentik.tasks.schedules.common import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel class PlexAuthenticationChallenge(LoginChallengeMixin, Challenge): @@ -36,7 +39,7 @@ class PlexAuthenticationChallengeResponse(ChallengeResponse): component = CharField(default="ak-source-plex") -class PlexSource(Source): +class PlexSource(ScheduledModel, Source): """Authenticate against plex.tv""" client_id = models.TextField( @@ -72,6 +75,19 @@ class PlexSource(Source): def property_mapping_type(self) -> type[PropertyMapping]: return PlexSourcePropertyMapping + @property + def schedule_specs(self) -> list[ScheduleSpec]: + from authentik.sources.plex.tasks import check_plex_token + + return [ + ScheduleSpec( + actor=check_plex_token, + uid=self.slug, + args=(self.pk,), + crontab=f"{fqdn_rand(self.pk)} */3 * * *", + ), + ] + def get_base_user_properties(self, info: dict[str, Any], **kwargs): return { "username": info.get("username"), diff --git a/authentik/sources/plex/settings.py b/authentik/sources/plex/settings.py deleted file mode 100644 index 5884bbe220..0000000000 --- a/authentik/sources/plex/settings.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Plex source settings""" - -from celery.schedules import crontab - -from 
authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "check_plex_token": { - "task": "authentik.sources.plex.tasks.check_plex_token_all", - "schedule": crontab(minute=fqdn_rand("check_plex_token"), hour="*/3"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/sources/plex/tasks.py b/authentik/sources/plex/tasks.py index 06b108edb6..ff2c54e81e 100644 --- a/authentik/sources/plex/tasks.py +++ b/authentik/sources/plex/tasks.py @@ -1,40 +1,34 @@ """Plex tasks""" +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from requests import RequestException -from authentik.events.models import Event, EventAction, TaskStatus -from authentik.events.system_tasks import SystemTask +from authentik.events.models import Event, EventAction from authentik.lib.utils.errors import exception_to_string -from authentik.root.celery import CELERY_APP from authentik.sources.plex.models import PlexSource from authentik.sources.plex.plex import PlexAuth +from authentik.tasks.models import Task -@CELERY_APP.task() -def check_plex_token_all(): - """Check plex token for all plex sources""" - for source in PlexSource.objects.all(): - check_plex_token.delay(source.slug) - - -@CELERY_APP.task(bind=True, base=SystemTask) -def check_plex_token(self: SystemTask, source_slug: int): +@actor(description=_("Check the validity of a Plex source.")) +def check_plex_token(source_pk: str): """Check the validity of a Plex source.""" - sources = PlexSource.objects.filter(slug=source_slug) + self: Task = CurrentTask.get_task() + sources = PlexSource.objects.filter(pk=source_pk) if not sources.exists(): return source: PlexSource = sources.first() - self.set_uid(source.slug) auth = PlexAuth(source, source.plex_token) try: auth.get_user_info() - self.set_status(TaskStatus.SUCCESSFUL, "Plex token is valid.") + self.info("Plex token is valid.") except RequestException as exc: 
error = exception_to_string(exc) if len(source.plex_token) > 0: error = error.replace(source.plex_token, "$PLEX_TOKEN") - self.set_status( - TaskStatus.ERROR, + self.error( "Plex token is invalid/an error occurred:", error, ) diff --git a/authentik/sources/plex/tests.py b/authentik/sources/plex/tests.py index 53be3c77ce..aaab8383fc 100644 --- a/authentik/sources/plex/tests.py +++ b/authentik/sources/plex/tests.py @@ -8,7 +8,7 @@ from authentik.events.models import Event, EventAction from authentik.lib.generators import generate_key from authentik.sources.plex.models import PlexSource from authentik.sources.plex.plex import PlexAuth -from authentik.sources.plex.tasks import check_plex_token_all +from authentik.sources.plex.tasks import check_plex_token USER_INFO_RESPONSE = { "id": 1234123419, @@ -76,11 +76,11 @@ class TestPlexSource(TestCase): """Test token check task""" with Mocker() as mocker: mocker.get("https://plex.tv/api/v2/user", json=USER_INFO_RESPONSE) - check_plex_token_all() + check_plex_token.send(self.source.pk) self.assertFalse(Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists()) with Mocker() as mocker: mocker.get("https://plex.tv/api/v2/user", exc=RequestException()) - check_plex_token_all() + check_plex_token.send(self.source.pk) self.assertTrue(Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists()) def test_user_base_properties(self): diff --git a/authentik/stages/authenticator_duo/api.py b/authentik/stages/authenticator_duo/api.py index cccdd1dc1d..7a07b1dd70 100644 --- a/authentik/stages/authenticator_duo/api.py +++ b/authentik/stages/authenticator_duo/api.py @@ -1,5 +1,7 @@ """AuthenticatorDuoStage API Views""" +from typing import Any + from django.http import Http404 from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer @@ -15,11 +17,11 @@ from structlog.stdlib import get_logger from authentik.core.api.groups import 
GroupMemberSerializer from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import ModelSerializer +from authentik.core.models import User from authentik.flows.api.stages import StageSerializer from authentik.rbac.decorators import permission_required from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice from authentik.stages.authenticator_duo.stage import SESSION_KEY_DUO_ENROLL -from authentik.stages.authenticator_duo.tasks import duo_import_devices LOGGER = get_logger() @@ -159,9 +161,45 @@ class AuthenticatorDuoStageViewSet(UsedByMixin, ModelViewSet): }, status=400, ) - result = duo_import_devices.delay(str(stage.pk)).get() + result = self._duo_import_devices(stage) return Response(data=result, status=200 if result["error"] == "" else 400) + def _duo_import_devices(self, stage: AuthenticatorDuoStage) -> dict[str, Any]: + """ + Import duo devices. This used to be a blocking task. + """ + created = 0 + if stage.admin_integration_key == "": + LOGGER.info("Stage does not have admin integration configured", stage=stage) + return {"error": "Stage does not have admin integration configured", "count": created} + client = stage.admin_client() + try: + for duo_user in client.get_users_iterator(): + user_id = duo_user.get("user_id") + username = duo_user.get("username") + + user = User.objects.filter(username=username).first() + if not user: + LOGGER.debug("User not found", username=username) + continue + device = DuoDevice.objects.filter( + duo_user_id=user_id, user=user, stage=stage + ).first() + if device: + LOGGER.debug("User already has a device with ID", id=user_id) + continue + DuoDevice.objects.create( + duo_user_id=user_id, + user=user, + stage=stage, + name="Imported Duo Authenticator", + ) + created += 1 + return {"error": "", "count": created} + except RuntimeError as exc: + LOGGER.warning("failed to get users from duo", exc=exc) + return {"error": str(exc), "count": created} + class 
DuoDeviceSerializer(ModelSerializer): """Serializer for Duo authenticator devices""" diff --git a/authentik/stages/authenticator_duo/tasks.py b/authentik/stages/authenticator_duo/tasks.py deleted file mode 100644 index b97c1e39ee..0000000000 --- a/authentik/stages/authenticator_duo/tasks.py +++ /dev/null @@ -1,47 +0,0 @@ -"""duo tasks""" - -from structlog.stdlib import get_logger - -from authentik.core.models import User -from authentik.root.celery import CELERY_APP -from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice - -LOGGER = get_logger() - - -@CELERY_APP.task() -def duo_import_devices(stage_pk: str): - """Import duo devices""" - created = 0 - stage: AuthenticatorDuoStage = AuthenticatorDuoStage.objects.filter(pk=stage_pk).first() - if not stage: - LOGGER.info("No stage found", pk=stage_pk) - return {"error": "No stage found", "count": created} - if stage.admin_integration_key == "": - LOGGER.info("Stage does not have admin integration configured", stage=stage) - return {"error": "Stage does not have admin integration configured", "count": created} - client = stage.admin_client() - try: - for duo_user in client.get_users_iterator(): - user_id = duo_user.get("user_id") - username = duo_user.get("username") - - user = User.objects.filter(username=username).first() - if not user: - LOGGER.debug("User not found", username=username) - continue - device = DuoDevice.objects.filter(duo_user_id=user_id, user=user, stage=stage).first() - if device: - LOGGER.debug("User already has a device with ID", id=user_id) - continue - DuoDevice.objects.create( - duo_user_id=user_id, - user=user, - stage=stage, - name="Imported Duo Authenticator", - ) - created += 1 - return {"error": "", "count": created} - except RuntimeError as exc: - LOGGER.warning("failed to get users from duo", exc=exc) - return {"error": str(exc), "count": created} diff --git a/authentik/stages/authenticator_email/tests.py b/authentik/stages/authenticator_email/tests.py 
index 8a3e9b2884..14533e726f 100644 --- a/authentik/stages/authenticator_email/tests.py +++ b/authentik/stages/authenticator_email/tests.py @@ -1,10 +1,11 @@ """Test Email Authenticator API""" from datetime import timedelta -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import PropertyMock, patch from django.core import mail -from django.core.mail.backends.smtp import EmailBackend +from django.core.mail.backends.locmem import EmailBackend +from django.core.mail.backends.smtp import EmailBackend as SMTPEmailBackend from django.db.utils import IntegrityError from django.template.exceptions import TemplateDoesNotExist from django.urls import reverse @@ -83,24 +84,28 @@ class TestAuthenticatorEmailStage(FlowTestCase): self.assertTrue(self.device.verify_token(token)) self.assertIsNone(self.device.token) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_stage_no_prefill(self): """Test stage without prefilled email""" self.client.force_login(self.user_noemail) - with patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - PropertyMock(return_value=EmailBackend), - ): - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - ) - self.assertStageResponse( - response, - self.flow, - self.user_noemail, - component="ak-stage-authenticator-email", - email_required=True, - ) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + self.assertStageResponse( + response, + self.flow, + self.user_noemail, + component="ak-stage-authenticator-email", + email_required=True, + ) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_stage_submit(self): """Test stage email submission""" # Initialize the flow 
@@ -115,34 +120,18 @@ class TestAuthenticatorEmailStage(FlowTestCase): email_required=False, ) - # Test email submission with locmem backend - def mock_send_mails(stage, *messages): - """Mock send_mails to send directly""" - for message in messages: - message.send() - - with ( - patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - return_value=EmailBackend, - ), - patch( - "authentik.stages.authenticator_email.stage.send_mails", - side_effect=mock_send_mails, - ), - ): - response = self.client.post( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - data={"component": "ak-stage-authenticator-email", "email": "test@example.com"}, - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(mail.outbox), 1) - sent_mail = mail.outbox[0] - self.assertEqual(sent_mail.subject, self.stage.subject) - self.assertEqual(sent_mail.to, [f"{self.user} "]) - # Get from_address from global email config to test if global settings are being used - from_address_global = CONFIG.get("email.from") - self.assertEqual(sent_mail.from_email, from_address_global) + response = self.client.post( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + data={"component": "ak-stage-authenticator-email", "email": "test@example.com"}, + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(mail.outbox), 2) + sent_mail = mail.outbox[1] + self.assertEqual(sent_mail.subject, self.stage.subject) + self.assertEqual(sent_mail.to, [f"{self.user} "]) + # Get from_address from global email config to test if global settings are being used + from_address_global = CONFIG.get("email.from") + self.assertEqual(sent_mail.from_email, from_address_global) self.assertStageResponse( response, @@ -196,110 +185,110 @@ class TestAuthenticatorEmailStage(FlowTestCase): with self.assertRaises(TemplateDoesNotExist): self.stage.send(self.device) + @patch( + 
"authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_challenge_response_validation(self): """Test challenge response validation""" # Initialize the flow self.client.force_login(self.user_noemail) - with patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - PropertyMock(return_value=EmailBackend), - ): - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - ) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) - # Test missing code and email - response = self.client.post( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - data={"component": "ak-stage-authenticator-email"}, - ) - self.assertIn("email required", str(response.content)) + # Test missing code and email + response = self.client.post( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + data={"component": "ak-stage-authenticator-email"}, + ) + self.assertIn("email required", str(response.content)) - # Test invalid code - response = self.client.post( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - data={"component": "ak-stage-authenticator-email", "code": "000000"}, - ) - self.assertIn("Code does not match", str(response.content)) + # Test invalid code + response = self.client.post( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + data={"component": "ak-stage-authenticator-email", "code": "000000"}, + ) + self.assertIn("Code does not match", str(response.content)) - # Test valid code - self.client.force_login(self.user) - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - ) - device = self.device - token = device.token - response = self.client.post( - 
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - data={"component": "ak-stage-authenticator-email", "code": token}, - ) - self.assertEqual(response.status_code, 200) - self.assertTrue(device.confirmed) + # Test valid code + self.client.force_login(self.user) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + device = self.device + token = device.token + response = self.client.post( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + data={"component": "ak-stage-authenticator-email", "code": token}, + ) + self.assertEqual(response.status_code, 200) + self.assertTrue(device.confirmed) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_challenge_generation(self): """Test challenge generation""" # Test with masked email - with patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - PropertyMock(return_value=EmailBackend), - ): - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - ) - self.assertStageResponse( - response, - self.flow, - self.user, - component="ak-stage-authenticator-email", - email_required=False, - ) - masked_email = mask_email(self.user.email) - self.assertEqual(masked_email, response.json()["email"]) - self.client.logout() + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + self.assertStageResponse( + response, + self.flow, + self.user, + component="ak-stage-authenticator-email", + email_required=False, + ) + masked_email = mask_email(self.user.email) + self.assertEqual(masked_email, response.json()["email"]) + self.client.logout() - # Test without email - self.client.force_login(self.user_noemail) - response = self.client.get( - reverse("authentik_api:flow-executor", 
kwargs={"flow_slug": self.flow.slug}), - ) - self.assertStageResponse( - response, - self.flow, - self.user_noemail, - component="ak-stage-authenticator-email", - email_required=True, - ) - self.assertIsNone(response.json()["email"]) + # Test without email + self.client.force_login(self.user_noemail) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + self.assertStageResponse( + response, + self.flow, + self.user_noemail, + component="ak-stage-authenticator-email", + email_required=True, + ) + self.assertIsNone(response.json()["email"]) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_session_management(self): """Test session device management""" # Test device creation in session - with patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - PropertyMock(return_value=EmailBackend), - ): - # Delete any existing devices for this test - EmailDevice.objects.filter(user=self.user).delete() + # Delete any existing devices for this test + EmailDevice.objects.filter(user=self.user).delete() - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - ) - self.assertIn(SESSION_KEY_EMAIL_DEVICE, self.client.session) - device = self.client.session[SESSION_KEY_EMAIL_DEVICE] - self.assertIsInstance(device, EmailDevice) - self.assertFalse(device.confirmed) - self.assertEqual(device.user, self.user) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + self.assertIn(SESSION_KEY_EMAIL_DEVICE, self.client.session) + device = self.client.session[SESSION_KEY_EMAIL_DEVICE] + self.assertIsInstance(device, EmailDevice) + self.assertFalse(device.confirmed) + self.assertEqual(device.user, self.user) - # Test device confirmation and cleanup - device.confirmed = True - 
device.email = "new_test@authentik.local" # Use a different email - self.client.session[SESSION_KEY_EMAIL_DEVICE] = device - self.client.session.save() - response = self.client.post( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), - data={"component": "ak-stage-authenticator-email", "code": device.token}, - ) - self.assertEqual(response.status_code, 200) + # Test device confirmation and cleanup + device.confirmed = True + device.email = "new_test@authentik.local" # Use a different email + self.client.session[SESSION_KEY_EMAIL_DEVICE] = device + self.client.session.save() + response = self.client.post( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + data={"component": "ak-stage-authenticator-email", "code": device.token}, + ) + self.assertEqual(response.status_code, 200) def test_model_properties_and_methods(self): """Test model properties""" @@ -307,30 +296,23 @@ class TestAuthenticatorEmailStage(FlowTestCase): stage = self.stage self.assertEqual(stage.serializer, AuthenticatorEmailStageSerializer) - self.assertIsInstance(stage.backend, EmailBackend) + self.assertIsInstance(stage.backend, SMTPEmailBackend) self.assertEqual(device.serializer, EmailDeviceSerializer) # Test AuthenticatorEmailStage send method - with patch( - "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", - return_value=EmailBackend, - ): - self.device.generate_token() - # Test EmailDevice _compose_email method - message = self.device._compose_email() - self.assertIsInstance(message, TemplateEmailMessage) - self.assertEqual(message.subject, self.stage.subject) - self.assertEqual(message.to, [f"{self.user.name} <{self.device.email}>"]) - self.assertTrue(self.device.token in message.body) - # Test AuthenticatorEmailStage send method - self.stage.send(device) + self.device.generate_token() + # Test EmailDevice _compose_email method + message = self.device._compose_email() + 
self.assertIsInstance(message, TemplateEmailMessage) + self.assertEqual(message.subject, self.stage.subject) + self.assertEqual(message.to, [f"{self.user.name} <{self.device.email}>"]) + self.assertTrue(self.device.token in message.body) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_email_tasks(self): - email_send_mock = MagicMock() - with patch( - "authentik.stages.email.tasks.send_mails", - email_send_mock, - ): - # Test AuthenticatorEmailStage send method - self.stage.send(self.device) - email_send_mock.assert_called_once() + # Test AuthenticatorEmailStage send method + self.stage.send(self.device) + self.assertEqual(len(mail.outbox), 1) diff --git a/authentik/stages/authenticator_validate/tests/test_webauthn.py b/authentik/stages/authenticator_validate/tests/test_webauthn.py index 76b7053db4..093af4ae0d 100644 --- a/authentik/stages/authenticator_validate/tests/test_webauthn.py +++ b/authentik/stages/authenticator_validate/tests/test_webauthn.py @@ -134,7 +134,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_device_challenge_webauthn_restricted(self): """Test webauthn (getting device challenges with a webauthn device that is not allowed due to aaguid restrictions)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() request = get_request("/") request.user = self.user @@ -260,7 +260,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_validate_challenge_unrestricted(self): """Test webauthn authentication (unrestricted webauthn device)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() device = WebAuthnDevice.objects.create( user=self.user, public_key=( @@ -334,7 +334,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_validate_challenge_restricted(self): """Test webauthn authentication 
(restricted device type, failure)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() device = WebAuthnDevice.objects.create( user=self.user, public_key=( diff --git a/authentik/stages/authenticator_webauthn/apps.py b/authentik/stages/authenticator_webauthn/apps.py index 059e946c13..d65f6cf3c2 100644 --- a/authentik/stages/authenticator_webauthn/apps.py +++ b/authentik/stages/authenticator_webauthn/apps.py @@ -1,6 +1,8 @@ """authentik webauthn app config""" from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec class AuthentikStageAuthenticatorWebAuthnConfig(ManagedAppConfig): @@ -10,3 +12,14 @@ class AuthentikStageAuthenticatorWebAuthnConfig(ManagedAppConfig): label = "authentik_stages_authenticator_webauthn" verbose_name = "authentik Stages.Authenticator.WebAuthn" default = True + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import + + return [ + ScheduleSpec( + actor=webauthn_mds_import, + crontab=f"{fqdn_rand('webauthn_mds_import')} {fqdn_rand('webauthn_mds_import', 24)} * * {fqdn_rand('webauthn_mds_import', 7)}", # noqa: E501 + ), + ] diff --git a/authentik/stages/authenticator_webauthn/settings.py b/authentik/stages/authenticator_webauthn/settings.py deleted file mode 100644 index 1ecd5af5e4..0000000000 --- a/authentik/stages/authenticator_webauthn/settings.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Stage authenticator webauthn Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "stages_authenticator_webauthn_import_mds": { - "task": "authentik.stages.authenticator_webauthn.tasks.webauthn_mds_import", - "schedule": crontab( - minute=fqdn_rand("webauthn_mds_import"), - hour=fqdn_rand("webauthn_mds_import", 24), - 
day_of_week=fqdn_rand("webauthn_mds_import", 7), - ), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/stages/authenticator_webauthn/tasks.py b/authentik/stages/authenticator_webauthn/tasks.py index 15ed03be13..a5e0370bd7 100644 --- a/authentik/stages/authenticator_webauthn/tasks.py +++ b/authentik/stages/authenticator_webauthn/tasks.py @@ -6,15 +6,16 @@ from pathlib import Path from django.core.cache import cache from django.db.transaction import atomic +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor from fido2.mds3 import filter_revoked, parse_blob -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask, prefill_task -from authentik.root.celery import CELERY_APP from authentik.stages.authenticator_webauthn.models import ( UNKNOWN_DEVICE_TYPE_AAGUID, WebAuthnDeviceType, ) +from authentik.tasks.models import Task CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no" AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json" @@ -29,13 +30,10 @@ def mds_ca() -> bytes: return _raw_root.read() -@CELERY_APP.task( - bind=True, - base=SystemTask, -) -@prefill_task -def webauthn_mds_import(self: SystemTask, force=False): +@actor(description=_("Background task to import FIDO Alliance MDS blob and AAGUIDs into database.")) +def webauthn_mds_import(force=False): """Background task to import FIDO Alliance MDS blob and AAGUIDs into database""" + self: Task = CurrentTask.get_task() with open(MDS_BLOB_PATH, mode="rb") as _raw_blob: blob = parse_blob(_raw_blob.read(), mds_ca()) to_create_update = [ @@ -90,7 +88,4 @@ def webauthn_mds_import(self: SystemTask, force=False): unique_fields=["aaguid"], ) - self.set_status( - TaskStatus.SUCCESSFUL, - "Successfully imported FIDO Alliance MDS blobs and AAGUIDs.", - ) + self.info("Successfully imported FIDO Alliance MDS blobs and AAGUIDs.") 
diff --git a/authentik/stages/authenticator_webauthn/tests.py b/authentik/stages/authenticator_webauthn/tests.py index 9e7d8b7f2e..e075876700 100644 --- a/authentik/stages/authenticator_webauthn/tests.py +++ b/authentik/stages/authenticator_webauthn/tests.py @@ -142,7 +142,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): def test_register_restricted_device_type_deny(self): """Test registration with restricted devices (fail)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True) self.stage.device_type_restrictions.set( WebAuthnDeviceType.objects.filter(description="YubiKey 5 Series") ) @@ -205,7 +205,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): def test_register_restricted_device_type_allow(self): """Test registration with restricted devices (allow)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True) self.stage.device_type_restrictions.set( WebAuthnDeviceType.objects.filter(description="iCloud Keychain") ) @@ -254,7 +254,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): def test_register_restricted_device_type_allow_unknown(self): """Test registration with restricted devices (allow, unknown device type)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True) WebAuthnDeviceType.objects.filter(description="iCloud Keychain").delete() self.stage.device_type_restrictions.set( WebAuthnDeviceType.objects.filter(aaguid=UNKNOWN_DEVICE_TYPE_AAGUID) diff --git a/authentik/stages/consent/tests.py b/authentik/stages/consent/tests.py index 9099fc7cc8..0ef001b225 100644 --- a/authentik/stages/consent/tests.py +++ b/authentik/stages/consent/tests.py @@ -174,7 +174,7 @@ class TestConsentStage(FlowTestCase): ) with freeze_time() as frozen_time: frozen_time.tick(timedelta(seconds=3)) - clean_expired_models.delay().get() + clean_expired_models.send() self.assertFalse( UserConsent.objects.filter(user=self.user, application=self.application).exists() ) diff 
--git a/authentik/stages/email/tasks.py b/authentik/stages/email/tasks.py index b4b6ba03b5..af2b14ab3c 100644 --- a/authentik/stages/email/tasks.py +++ b/authentik/stages/email/tasks.py @@ -1,22 +1,23 @@ """email stage tasks""" from email.utils import make_msgid -from smtplib import SMTPException from typing import Any -from celery import group from django.core.mail import EmailMultiAlternatives from django.core.mail.utils import DNS_NAME from django.utils.text import slugify +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask +from dramatiq.actor import actor +from dramatiq.composition import group from structlog.stdlib import get_logger -from authentik.events.models import Event, EventAction, TaskStatus -from authentik.events.system_tasks import SystemTask +from authentik.events.models import Event, EventAction from authentik.lib.utils.reflection import class_to_path, path_to_class -from authentik.root.celery import CELERY_APP from authentik.stages.authenticator_email.models import AuthenticatorEmailStage from authentik.stages.email.models import EmailStage from authentik.stages.email.utils import logo_data +from authentik.tasks.models import Task LOGGER = get_logger() @@ -30,16 +31,14 @@ def send_mails( stage: Either an EmailStage or AuthenticatorEmailStage instance messages: List of email messages to send Returns: - Celery group promise for the email sending tasks + Dramatiq group promise for the email sending tasks """ tasks = [] # Use the class path instead of the class itself for serialization stage_class_path = class_to_path(stage.__class__) for message in messages: - tasks.append(send_mail.s(message.__dict__, stage_class_path, str(stage.pk))) - lazy_group = group(*tasks) - promise = lazy_group() - return promise + tasks.append(send_mail.message(message.__dict__, stage_class_path, str(stage.pk))) + return group(tasks).run() def get_email_body(email: EmailMultiAlternatives) -> str: @@ -50,84 
+49,63 @@ def get_email_body(email: EmailMultiAlternatives) -> str: return email.body -@CELERY_APP.task( - bind=True, - autoretry_for=( - SMTPException, - ConnectionError, - OSError, - ), - retry_backoff=True, - base=SystemTask, -) +@actor(description=_("Send email.")) def send_mail( - self: SystemTask, message: dict[Any, Any], stage_class_path: str | None = None, email_stage_pk: str | None = None, ): """Send Email for Email Stage. Retries are scheduled automatically.""" - self.save_on_success = False + self: Task = CurrentTask.get_task() message_id = make_msgid(domain=DNS_NAME) self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_"))) - try: - if not stage_class_path or not email_stage_pk: - stage = EmailStage(use_global_settings=True) - else: - stage_class = path_to_class(stage_class_path) - stages = stage_class.objects.filter(pk=email_stage_pk) - if not stages.exists(): - self.set_status( - TaskStatus.WARNING, - "Email stage does not exist anymore. Discarding message.", - ) - return - stage: EmailStage | AuthenticatorEmailStage = stages.first() - try: - backend = stage.backend - except ValueError as exc: - LOGGER.warning("failed to get email backend", exc=exc) - self.set_error(exc) + if not stage_class_path or not email_stage_pk: + stage = EmailStage(use_global_settings=True) + else: + stage_class = path_to_class(stage_class_path) + stages = stage_class.objects.filter(pk=email_stage_pk) + if not stages.exists(): + self.warning("Email stage does not exist anymore. 
Discarding message.") return - backend.open() - # Since django's EmailMessage objects are not JSON serialisable, - # we need to rebuild them from a dict - message_object = EmailMultiAlternatives() - for key, value in message.items(): - setattr(message_object, key, value) - if not stage.use_global_settings: - message_object.from_email = stage.from_address - # Because we use the Message-ID as UID for the task, manually assign it - message_object.extra_headers["Message-ID"] = message_id + stage: EmailStage | AuthenticatorEmailStage = stages.first() + try: + backend = stage.backend + except ValueError as exc: + LOGGER.warning("failed to get email backend", exc=exc) + self.error(exc) + return + backend.open() + # Since django's EmailMessage objects are not JSON serialisable, + # we need to rebuild them from a dict + message_object = EmailMultiAlternatives() + for key, value in message.items(): + setattr(message_object, key, value) + if not stage.use_global_settings: + message_object.from_email = stage.from_address + # Because we use the Message-ID as UID for the task, manually assign it + message_object.extra_headers["Message-ID"] = message_id - # Add the logo if it is used in the email body (we can't add it in the - # previous message since MIMEImage can't be converted to json) - body = get_email_body(message_object) - if "cid:logo" in body: - message_object.attach(logo_data()) + # Add the logo if it is used in the email body (we can't add it in the + # previous message since MIMEImage can't be converted to json) + body = get_email_body(message_object) + if "cid:logo" in body: + message_object.attach(logo_data()) - if ( - message_object.to - and isinstance(message_object.to[0], str) - and "=?utf-8?" in message_object.to[0] - ): - message_object.to = [message_object.to[0].split("<")[-1].replace(">", "")] + if ( + message_object.to + and isinstance(message_object.to[0], str) + and "=?utf-8?" 
in message_object.to[0] + ): + message_object.to = [message_object.to[0].split("<")[-1].replace(">", "")] - LOGGER.debug("Sending mail", to=message_object.to) - backend.send_messages([message_object]) - Event.new( - EventAction.EMAIL_SENT, - message=f"Email to {', '.join(message_object.to)} sent", - subject=message_object.subject, - body=get_email_body(message_object), - from_email=message_object.from_email, - to_email=message_object.to, - ).save() - self.set_status( - TaskStatus.SUCCESSFUL, - "Successfully sent Mail.", - ) - except (SMTPException, ConnectionError, OSError) as exc: - LOGGER.debug("Error sending email, retrying...", exc=exc) - self.set_error(exc) - raise exc + LOGGER.debug("Sending mail", to=message_object.to) + backend.send_messages([message_object]) + Event.new( + EventAction.EMAIL_SENT, + message=f"Email to {', '.join(message_object.to)} sent", + subject=message_object.subject, + body=get_email_body(message_object), + from_email=message_object.from_email, + to_email=message_object.to, + ).save() + self.info("Successfully sent mail.") diff --git a/authentik/stages/email/tests/test_stage.py b/authentik/stages/email/tests/test_stage.py index b43963b0a3..5d0a4d6d6c 100644 --- a/authentik/stages/email/tests/test_stage.py +++ b/authentik/stages/email/tests/test_stage.py @@ -36,6 +36,10 @@ class TestEmailStage(FlowTestCase): ) self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) + @patch( + "authentik.stages.email.models.EmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_rendering(self): """Test with pending user""" plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) diff --git a/authentik/stages/email/tests/test_tasks.py b/authentik/stages/email/tests/test_tasks.py index 158b21636c..e990a73f20 100644 --- a/authentik/stages/email/tests/test_tasks.py +++ b/authentik/stages/email/tests/test_tasks.py @@ -1,12 +1,13 @@ """Test email stage tasks""" 
-from unittest.mock import patch +from unittest.mock import PropertyMock, patch +from django.core import mail from django.core.mail import EmailMultiAlternatives +from django.core.mail.backends.locmem import EmailBackend from django.test import TestCase from authentik.core.tests.utils import create_test_admin_user -from authentik.lib.utils.reflection import class_to_path from authentik.stages.authenticator_email.models import AuthenticatorEmailStage from authentik.stages.email.models import EmailStage from authentik.stages.email.tasks import get_email_body, send_mails @@ -39,20 +40,22 @@ class TestEmailTasks(TestCase): message.body = "plain text" self.assertEqual(get_email_body(message), "plain text") + @patch( + "authentik.stages.email.models.EmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_send_mails_email_stage(self): """Test send_mails with EmailStage""" message = EmailMultiAlternatives() - with patch("authentik.stages.email.tasks.send_mail") as mock_send: - send_mails(self.stage, message) - mock_send.s.assert_called_once_with( - message.__dict__, class_to_path(EmailStage), str(self.stage.pk) - ) + send_mails(self.stage, message) + self.assertEqual(len(mail.outbox), 1) + @patch( + "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", + PropertyMock(return_value=EmailBackend), + ) def test_send_mails_authenticator_stage(self): """Test send_mails with AuthenticatorEmailStage""" message = EmailMultiAlternatives() - with patch("authentik.stages.email.tasks.send_mail") as mock_send: - send_mails(self.auth_stage, message) - mock_send.s.assert_called_once_with( - message.__dict__, class_to_path(AuthenticatorEmailStage), str(self.auth_stage.pk) - ) + send_mails(self.auth_stage, message) + self.assertEqual(len(mail.outbox), 1) diff --git a/authentik/tasks/__init__.py b/authentik/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/api/__init__.py 
b/authentik/tasks/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/api/tasks.py b/authentik/tasks/api/tasks.py new file mode 100644 index 0000000000..40bea048d7 --- /dev/null +++ b/authentik/tasks/api/tasks.py @@ -0,0 +1,138 @@ +from django_dramatiq_postgres.models import TaskState +from django_filters.filters import BooleanFilter, MultipleChoiceFilter +from django_filters.filterset import FilterSet +from dramatiq.actor import Actor +from dramatiq.broker import get_broker +from dramatiq.errors import ActorNotFound +from dramatiq.message import Message +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiResponse, extend_schema +from rest_framework.decorators import action +from rest_framework.fields import ReadOnlyField, SerializerMethodField +from rest_framework.mixins import ( + ListModelMixin, + RetrieveModelMixin, +) +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet +from structlog.stdlib import get_logger + +from authentik.core.api.utils import ModelSerializer +from authentik.events.logs import LogEventSerializer +from authentik.rbac.decorators import permission_required +from authentik.tasks.models import Task, TaskStatus +from authentik.tenants.utils import get_current_tenant + +LOGGER = get_logger() + + +class TaskSerializer(ModelSerializer): + rel_obj_app_label = ReadOnlyField(source="rel_obj_content_type.app_label") + rel_obj_model = ReadOnlyField(source="rel_obj_content_type.model") + + messages = LogEventSerializer(many=True, source="_messages") + previous_messages = LogEventSerializer(many=True, source="_previous_messages") + description = SerializerMethodField() + + class Meta: + model = Task + fields = [ + "message_id", + "queue_name", + "actor_name", + "state", + "mtime", + "rel_obj_app_label", + "rel_obj_model", + "rel_obj_id", + "uid", + "messages", + "previous_messages", + 
"aggregated_status", + "description", + ] + + def get_description(self, instance: Task) -> str | None: + try: + actor: Actor = get_broker().get_actor(instance.actor_name) + except ActorNotFound: + LOGGER.warning("Could not find actor for schedule", schedule=instance) + return None + if "description" not in actor.options: + LOGGER.warning( + "Could not find description for actor", + task=instance, + actor=actor.actor_name, + ) + return None + return actor.options["description"] + + +class TaskFilter(FilterSet): + rel_obj_id__isnull = BooleanFilter("rel_obj_id", "isnull") + aggregated_status = MultipleChoiceFilter( + choices=TaskStatus.choices, + field_name="aggregated_status", + ) + + class Meta: + model = Task + fields = ( + "queue_name", + "actor_name", + "state", + "rel_obj_content_type__app_label", + "rel_obj_content_type__model", + "rel_obj_id", + "rel_obj_id__isnull", + "aggregated_status", + ) + + +class TaskViewSet( + RetrieveModelMixin, + ListModelMixin, + GenericViewSet, +): + queryset = Task.objects.none() + serializer_class = TaskSerializer + search_fields = ( + "message_id", + "queue_name", + "actor_name", + "state", + "rel_obj_content_type__app_label", + "rel_obj_content_type__model", + "rel_obj_id", + "_uid", + "aggregated_status", + ) + filterset_class = TaskFilter + ordering = ("-mtime",) + + def get_queryset(self): + return ( + Task.objects.select_related("rel_obj_content_type") + .defer("message", "result") + .filter(tenant=get_current_tenant()) + ) + + @permission_required(None, ["authentik_tasks.retry_task"]) + @extend_schema( + request=OpenApiTypes.NONE, + responses={ + 204: OpenApiResponse(description="Task retried successfully"), + 400: OpenApiResponse(description="Task is not in a retryable state"), + 404: OpenApiResponse(description="Task not found"), + }, + ) + @action(detail=True, methods=["POST"], permission_classes=[]) + def retry(self, request: Request, pk=None) -> Response: + """Retry task""" + task: Task = self.get_object() + if 
task.state not in (TaskState.REJECTED, TaskState.DONE): + return Response(status=400) + broker = get_broker() + broker.enqueue(Message.decode(task.message)) + return Response(status=204) diff --git a/authentik/tasks/api/workers.py b/authentik/tasks/api/workers.py new file mode 100644 index 0000000000..d6a855e3f5 --- /dev/null +++ b/authentik/tasks/api/workers.py @@ -0,0 +1,48 @@ +import pglock +from django.utils.timezone import now, timedelta +from drf_spectacular.utils import extend_schema, inline_serializer +from packaging.version import parse +from rest_framework.fields import BooleanField, CharField +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + +from authentik import get_full_version +from authentik.rbac.permissions import HasPermission +from authentik.tasks.models import WorkerStatus + + +class WorkerView(APIView): + """Get currently connected worker count.""" + + permission_classes = [HasPermission("authentik_rbac.view_system_info")] + + @extend_schema( + responses=inline_serializer( + "Worker", + fields={ + "worker_id": CharField(), + "version": CharField(), + "version_matching": BooleanField(), + }, + many=True, + ) + ) + def get(self, request: Request) -> Response: + response = [] + our_version = parse(get_full_version()) + for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(minutes=2)): + lock_id = f"goauthentik.io/worker/status/{status.pk}" + with pglock.advisory(lock_id, timeout=0, side_effect=pglock.Return) as acquired: + # The worker doesn't hold the lock, it isn't running + if acquired: + continue + version_matching = parse(status.version) == our_version + response.append( + { + "worker_id": f"{status.pk}@{status.hostname}", + "version": status.version, + "version_matching": version_matching, + } + ) + return Response(response) diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py new file mode 100644 index 0000000000..2cdb54e9f2 --- 
/dev/null +++ b/authentik/tasks/apps.py @@ -0,0 +1,21 @@ +from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.common import ScheduleSpec + + +class AuthentikTasksConfig(ManagedAppConfig): + name = "authentik.tasks" + label = "authentik_tasks" + verbose_name = "authentik Tasks" + default = True + + @property + def global_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.tasks.tasks import clean_worker_statuses + + return [ + ScheduleSpec( + actor=clean_worker_statuses, + crontab=f"{fqdn_rand('clean_worker_statuses')} {fqdn_rand('clean_worker_statuses', 24)} * * *", # noqa: E501 + ), + ] diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py new file mode 100644 index 0000000000..7c7ff24cef --- /dev/null +++ b/authentik/tasks/broker.py @@ -0,0 +1,11 @@ +from django.db.models import QuerySet +from django_dramatiq_postgres.broker import PostgresBroker +from structlog.stdlib import get_logger + +LOGGER = get_logger() + + +class Broker(PostgresBroker): + @property + def query_set(self) -> QuerySet: + return super().query_set.select_related("tenant").filter(tenant__ready=True) diff --git a/authentik/tasks/forks.py b/authentik/tasks/forks.py new file mode 100644 index 0000000000..b1689d0e87 --- /dev/null +++ b/authentik/tasks/forks.py @@ -0,0 +1,44 @@ +from signal import pause + +from structlog.stdlib import get_logger + +from authentik.lib.config import CONFIG + +LOGGER = get_logger() + + +def worker_healthcheck(): + import authentik.tasks.setup # noqa + from authentik.tasks.middleware import WorkerHealthcheckMiddleware + + host, _, port = CONFIG.get("listen.listen_http").rpartition(":") + + try: + port = int(port) + except ValueError: + LOGGER.error(f"Invalid port entered: {port}") + + WorkerHealthcheckMiddleware.run(host, port) + pause() + + +def worker_status(): + import authentik.tasks.setup # noqa + from authentik.tasks.middleware import WorkerStatusMiddleware + 
+ WorkerStatusMiddleware.run() + + +def worker_metrics(): + import authentik.tasks.setup # noqa + from authentik.tasks.middleware import MetricsMiddleware + + addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":") + + try: + port = int(port) + except ValueError: + LOGGER.error(f"Invalid port entered: {port}") + + MetricsMiddleware.run(addr, port) + pause() diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py new file mode 100644 index 0000000000..f74e075de5 --- /dev/null +++ b/authentik/tasks/middleware.py @@ -0,0 +1,210 @@ +import socket +from http.server import BaseHTTPRequestHandler +from time import sleep +from typing import Any + +import pglock +from django.db import OperationalError, connections +from django.utils.timezone import now +from django_dramatiq_postgres.middleware import HTTPServer +from django_dramatiq_postgres.middleware import MetricsMiddleware as BaseMetricsMiddleware +from django_redis import get_redis_connection +from dramatiq.broker import Broker +from dramatiq.message import Message +from dramatiq.middleware import Middleware +from redis.exceptions import RedisError +from structlog.stdlib import get_logger + +from authentik import get_full_version +from authentik.events.models import Event, EventAction +from authentik.tasks.models import Task, TaskStatus, WorkerStatus +from authentik.tenants.models import Tenant +from authentik.tenants.utils import get_current_tenant + +LOGGER = get_logger() + + +class TenantMiddleware(Middleware): + def before_enqueue(self, broker: Broker, message: Message, delay: int): + message.options["model_defaults"]["tenant"] = get_current_tenant() + + def before_process_message(self, broker: Broker, message: Message): + task: Task = message.options["task"] + task.tenant.activate() + + def after_process_message(self, *args, **kwargs): + Tenant.deactivate() + + after_skip_message = after_process_message + + +class RelObjMiddleware(Middleware): + @property + def actor_options(self): + 
return {"rel_obj"} + + def before_enqueue(self, broker: Broker, message: Message, delay: int): + message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj", None) + + +class MessagesMiddleware(Middleware): + def after_enqueue(self, broker: Broker, message: Message, delay: int): + task: Task = message.options["task"] + task_created: bool = message.options["task_created"] + if task_created: + task._messages.append( + Task._make_message( + str(type(self)), + TaskStatus.INFO, + "Task has been queued", + delay=delay, + ) + ) + else: + task._previous_messages.extend(task._messages) + task._messages = [ + Task._make_message( + str(type(self)), + TaskStatus.INFO, + "Task will be retried", + delay=delay, + ) + ] + task.save(update_fields=("_messages", "_previous_messages")) + + def before_process_message(self, broker: Broker, message: Message): + task: Task = message.options["task"] + task.log(str(type(self)), TaskStatus.INFO, "Task is being processed") + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + task: Task = message.options["task"] + if exception is None: + task.log(str(type(self)), TaskStatus.INFO, "Task finished processing without errors") + else: + task.log( + str(type(self)), + TaskStatus.ERROR, + exception, + ) + Event.new( + EventAction.SYSTEM_TASK_EXCEPTION, + message=f"Task {task.actor_name} encountered an error", + actor=task.actor_name, + ).with_exception(exception).save() + + def after_skip_message(self, broker: Broker, message: Message): + task: Task = message.options["task"] + task.log(str(type(self)), TaskStatus.INFO, "Task has been skipped") + + +class LoggingMiddleware(Middleware): + def __init__(self): + self.logger = get_logger() + + def after_enqueue(self, broker: Broker, message: Message, delay: int): + self.logger.info( + "Task enqueued", + task_id=message.message_id, + task_name=message.actor_name, + ) + + def 
before_process_message(self, broker: Broker, message: Message): + self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name) + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + self.logger.info( + "Task finished", + task_id=message.message_id, + task_name=message.actor_name, + exc=exception, + ) + + def after_skip_message(self, broker: Broker, message: Message): + self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name) + + +class DescriptionMiddleware(Middleware): + @property + def actor_options(self): + return {"description"} + + +class _healthcheck_handler(BaseHTTPRequestHandler): + def do_HEAD(self): + try: + for db_conn in connections.all(): + # Force connection reload + db_conn.connect() + _ = db_conn.cursor() + redis_conn = get_redis_connection() + redis_conn.ping() + self.send_response(200) + except (OperationalError, RedisError): # pragma: no cover + self.send_response(503) + self.send_header("Content-Type", "text/plain; charset=utf-8") + self.send_header("Content-Length", "0") + self.end_headers() + + do_GET = do_HEAD + + +class WorkerHealthcheckMiddleware(Middleware): + @property + def forks(self): + from authentik.tasks.forks import worker_healthcheck + + return [worker_healthcheck] + + @staticmethod + def run(addr: str, port: int): + try: + httpd = HTTPServer((addr, port), _healthcheck_handler) + httpd.serve_forever() + except OSError as exc: + get_logger(__name__, type(WorkerHealthcheckMiddleware)).warning( + "Port is already in use, not starting healthcheck server", + exc=exc, + ) + + +class WorkerStatusMiddleware(Middleware): + @property + def forks(self): + from authentik.tasks.forks import worker_status + + return [worker_status] + + @staticmethod + def run(): + status = WorkerStatus.objects.create( + hostname=socket.gethostname(), + version=get_full_version(), + ) + lock_id = 
f"goauthentik.io/worker/status/{status.pk}" + with pglock.advisory(lock_id, side_effect=pglock.Raise): + while True: + status.last_seen = now() + status.save(update_fields=("last_seen",)) + sleep(30) + + +class MetricsMiddleware(BaseMetricsMiddleware): + @property + def forks(self): + from authentik.tasks.forks import worker_metrics + + return [worker_metrics] diff --git a/authentik/tasks/migrations/0001_initial.py b/authentik/tasks/migrations/0001_initial.py new file mode 100644 index 0000000000..c18effb509 --- /dev/null +++ b/authentik/tasks/migrations/0001_initial.py @@ -0,0 +1,147 @@ +# Generated by Django 5.1.11 on 2025-06-25 14:50 + +import django.db.models.deletion +import django.utils.timezone +import pgtrigger.compiler +import pgtrigger.migrations +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("authentik_tenants", "0005_tenant_reputation_lower_limit_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.CreateModel( + name="WorkerStatus", + fields=[ + ("id", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ("hostname", models.TextField()), + ("version", models.TextField()), + ("last_seen", models.DateTimeField(auto_now_add=True)), + ], + options={ + "verbose_name": "Worker status", + "verbose_name_plural": "Worker statuses", + "default_permissions": [], + }, + ), + migrations.CreateModel( + name="Task", + fields=[ + ( + "message_id", + models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), + ), + ("queue_name", models.TextField(default="default", help_text="Queue name")), + ("actor_name", models.TextField(help_text="Dramatiq actor name")), + ("message", models.BinaryField(help_text="Message body", null=True)), + ( + "state", + models.CharField( + choices=[ + ("queued", "Queued"), + ("consumed", "Consumed"), + ("rejected", "Rejected"), + ("done", "Done"), + ], + 
default="queued", + help_text="Task status", + ), + ), + ( + "mtime", + models.DateTimeField( + default=django.utils.timezone.now, help_text="Task last modified time" + ), + ), + ("result", models.BinaryField(help_text="Task result", null=True)), + ("result_expiry", models.DateTimeField(help_text="Result expiry time", null=True)), + ("rel_obj_id", models.TextField(null=True)), + ("_uid", models.TextField(blank=True, null=True)), + ("_messages", models.JSONField(default=list)), + ("_previous_messages", models.JSONField(default=list)), + ( + "aggregated_status", + models.TextField( + choices=[ + ("queued", "Queued"), + ("consumed", "Consumed"), + ("rejected", "Rejected"), + ("done", "Done"), + ("info", "Info"), + ("warning", "Warning"), + ("error", "Error"), + ] + ), + ), + ( + "rel_obj_content_type", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + ( + "tenant", + models.ForeignKey( + help_text="Tenant this task belongs to", + on_delete=django.db.models.deletion.CASCADE, + to="authentik_tenants.tenant", + ), + ), + ], + options={ + "verbose_name": "Task", + "verbose_name_plural": "Tasks", + "permissions": [("retry_task", "Retry failed task")], + "abstract": False, + "default_permissions": ("view",), + "indexes": [ + models.Index(fields=["state", "mtime"], name="authentik_t_state_bb4a31_idx"), + models.Index( + fields=["rel_obj_content_type", "rel_obj_id"], + name="authentik_t_rel_obj_3a177a_idx", + ), + ], + }, + ), + pgtrigger.migrations.AddTrigger( + model_name="task", + trigger=pgtrigger.compiler.Trigger( + name="notify_enqueueing", + sql=pgtrigger.compiler.UpsertTriggerSql( + condition="WHEN (NEW.\"state\" = 'queued')", + constraint="CONSTRAINT", + func="\n PERFORM pg_notify(\n 'authentik.tasks.' 
|| NEW.queue_name || '.enqueue',\n NEW.message_id::text\n );\n RETURN NEW;\n ", + hash="0a9ee3db61e4d63fd72b31322fbb821706dd8a78", + operation="INSERT OR UPDATE", + pgid="pgtrigger_notify_enqueueing_0bc94", + table="authentik_tasks_task", + timing="DEFERRABLE INITIALLY DEFERRED", + when="AFTER", + ), + ), + ), + pgtrigger.migrations.AddTrigger( + model_name="task", + trigger=pgtrigger.compiler.Trigger( + name="update_aggregated_status", + sql=pgtrigger.compiler.UpsertTriggerSql( + func="\n NEW.aggregated_status := CASE\n WHEN NEW.state != 'done' THEN NEW.state\n ELSE COALESCE((\n SELECT CASE\n WHEN bool_or(msg->>'log_level' = 'error') THEN 'error'\n WHEN bool_or(msg->>'log_level' = 'warning') THEN 'warning'\n WHEN bool_or(msg->>'log_level' = 'info') THEN 'info'\n ELSE 'done'\n END\n FROM jsonb_array_elements(NEW._messages) AS msg\n ), 'done')\n END;\n\n RETURN NEW;\n ", + hash="ebc09bc08c1624966c0c58a52f243fe25a842058", + operation="INSERT OR UPDATE", + pgid="pgtrigger_update_aggregated_status_f18c4", + table="authentik_tasks_task", + when="BEFORE", + ), + ), + ), + ] diff --git a/authentik/tasks/migrations/__init__.py b/authentik/tasks/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py new file mode 100644 index 0000000000..85d4c3f99b --- /dev/null +++ b/authentik/tasks/models.py @@ -0,0 +1,169 @@ +from typing import Any +from uuid import UUID, uuid4 + +import pgtrigger +from django.contrib.contenttypes.fields import ContentType, GenericForeignKey, GenericRelation +from django.db import models +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.models import TaskBase, TaskState + +from authentik.events.logs import LogEvent +from authentik.events.utils import sanitize_item +from authentik.lib.models import SerializerModel +from authentik.lib.utils.errors import exception_to_dict +from authentik.tenants.models import Tenant + + +class 
TaskStatus(models.TextChoices): + """Task aggregated status. Reported by the task runners""" + + QUEUED = TaskState.QUEUED + CONSUMED = TaskState.CONSUMED + REJECTED = TaskState.REJECTED + DONE = TaskState.DONE + INFO = "info" + WARNING = "warning" + ERROR = "error" + + +class Task(SerializerModel, TaskBase): + tenant = models.ForeignKey( + Tenant, + on_delete=models.CASCADE, + help_text=_("Tenant this task belongs to"), + ) + + rel_obj_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True) + rel_obj_id = models.TextField(null=True) + rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id") + + _uid = models.TextField(blank=True, null=True) + _messages = models.JSONField(default=list) + _previous_messages = models.JSONField(default=list) + + aggregated_status = models.TextField(choices=TaskStatus.choices) + + class Meta(TaskBase.Meta): + default_permissions = ("view",) + permissions = [ + ("retry_task", _("Retry failed task")), + ] + indexes = TaskBase.Meta.indexes + ( + models.Index(fields=("rel_obj_content_type", "rel_obj_id")), + ) + triggers = TaskBase.Meta.triggers + ( + pgtrigger.Trigger( + name="update_aggregated_status", + operation=pgtrigger.Insert | pgtrigger.Update, + when=pgtrigger.Before, + func=f""" + NEW.aggregated_status := CASE + WHEN NEW.state != '{TaskState.DONE.value}' THEN NEW.state + ELSE COALESCE(( + SELECT CASE + WHEN bool_or(msg->>'log_level' = 'error') THEN 'error' + WHEN bool_or(msg->>'log_level' = 'warning') THEN 'warning' + WHEN bool_or(msg->>'log_level' = 'info') THEN 'info' + ELSE '{TaskState.DONE.value}' + END + FROM jsonb_array_elements(NEW._messages) AS msg + ), '{TaskState.DONE.value}') + END; + + RETURN NEW; + """, # nosec + ), + ) + + @property + def uid(self) -> str: + uid = str(self.actor_name) + if self._uid: + uid += f":{self._uid}" + return uid + + @property + def serializer(self): + from authentik.tasks.api.tasks import TaskSerializer + + return TaskSerializer + + def set_uid(self, uid: 
str | UUID, save: bool = False): + self._uid = str(uid) + if save: + self.save() + + @classmethod + def _make_message( + cls, logger: str, log_level: TaskStatus, message: str | Exception, **attributes + ) -> dict[str, Any]: + if isinstance(message, Exception): + attributes = { + "exception": exception_to_dict(message), + **attributes, + } + message = str(message) + log = LogEvent( + message, + logger=logger, + log_level=log_level.value, + attributes=attributes, + ) + return sanitize_item(log) + + def logs(self, logs: list[LogEvent]): + for log in logs: + self._messages.append(sanitize_item(log)) + + def log( + self, + logger: str, + log_level: TaskStatus, + message: str | Exception, + save: bool = False, + **attributes, + ): + self._messages: list + self._messages.append( + self._make_message( + logger, + log_level, + message, + **attributes, + ) + ) + if save: + self.save() + + def info(self, message: str | Exception, save: bool = False, **attributes): + self.log(self.uid, TaskStatus.INFO, message, save=save, **attributes) + + def warning(self, message: str | Exception, save: bool = False, **attributes): + self.log(self.uid, TaskStatus.WARNING, message, save=save, **attributes) + + def error(self, message: str | Exception, save: bool = False, **attributes): + self.log(self.uid, TaskStatus.ERROR, message, save=save, **attributes) + + +class TasksModel(models.Model): + tasks = GenericRelation( + Task, content_type_field="rel_obj_content_type", object_id_field="rel_obj_id" + ) + + class Meta: + abstract = True + + +class WorkerStatus(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4) + hostname = models.TextField() + version = models.TextField() + last_seen = models.DateTimeField(auto_now_add=True) + + class Meta: + default_permissions = [] + verbose_name = _("Worker status") + verbose_name_plural = _("Worker statuses") + + def __str__(self): + return f"{self.id} - {self.hostname} - {self.version} - {self.last_seen}" diff --git 
a/authentik/tasks/schedules/__init__.py b/authentik/tasks/schedules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/schedules/api.py b/authentik/tasks/schedules/api.py new file mode 100644 index 0000000000..42eb9640bc --- /dev/null +++ b/authentik/tasks/schedules/api.py @@ -0,0 +1,133 @@ +from django_filters.filters import BooleanFilter +from django_filters.filterset import FilterSet +from dramatiq.actor import Actor +from dramatiq.broker import get_broker +from dramatiq.errors import ActorNotFound +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiResponse, extend_schema +from rest_framework.decorators import action +from rest_framework.fields import ReadOnlyField +from rest_framework.mixins import ( + ListModelMixin, + RetrieveModelMixin, + UpdateModelMixin, +) +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import SerializerMethodField +from rest_framework.viewsets import GenericViewSet +from structlog.stdlib import get_logger + +from authentik.core.api.utils import ModelSerializer +from authentik.rbac.decorators import permission_required +from authentik.tasks.models import Task, TaskStatus +from authentik.tasks.schedules.models import Schedule + +LOGGER = get_logger() + + +class ScheduleSerializer(ModelSerializer): + rel_obj_app_label = ReadOnlyField(source="rel_obj_content_type.app_label") + rel_obj_model = ReadOnlyField(source="rel_obj_content_type.model") + + description = SerializerMethodField() + last_task_status = SerializerMethodField() + + class Meta: + model = Schedule + fields = ( + "id", + "identifier", + "uid", + "actor_name", + "rel_obj_app_label", + "rel_obj_model", + "rel_obj_id", + "crontab", + "paused", + "next_run", + "description", + "last_task_status", + ) + + def get_description(self, instance: Schedule) -> str | None: + try: + actor: Actor = get_broker().get_actor(instance.actor_name) 
+ except ActorNotFound: + LOGGER.warning("Could not find actor for schedule", schedule=instance) + return None + if "description" not in actor.options: + LOGGER.warning( + "Could not find description for actor", + schedule=instance, + actor=actor.actor_name, + ) + return None + return actor.options["description"] + + def get_last_task_status(self, instance: Schedule) -> TaskStatus | None: + last_task: Task = instance.tasks.defer("message", "result").order_by("-mtime").first() + if last_task: + return last_task.aggregated_status + return None + + +class ScheduleFilter(FilterSet): + rel_obj_id__isnull = BooleanFilter("rel_obj_id", "isnull") + + class Meta: + model = Schedule + fields = ( + "actor_name", + "rel_obj_content_type__app_label", + "rel_obj_content_type__model", + "rel_obj_id", + "rel_obj_id__isnull", + "paused", + ) + + +class ScheduleViewSet( + RetrieveModelMixin, + UpdateModelMixin, + ListModelMixin, + GenericViewSet, +): + queryset = ( + Schedule.objects.select_related("rel_obj_content_type") + .defer("args", "kwargs", "options") + .all() + ) + serializer_class = ScheduleSerializer + search_fields = ( + "id", + "identifier", + "_uid", + "actor_name", + "rel_obj_content_type__app_label", + "rel_obj_content_type__model", + "rel_obj_id", + "description", + ) + filterset_class = ScheduleFilter + ordering = ( + "next_run", + "actor_name", + "identifier", + ) + + @permission_required("authentik_tasks_schedules.send_schedule") + @extend_schema( + request=OpenApiTypes.NONE, + responses={ + 204: OpenApiResponse(description="Schedule sent successfully"), + 404: OpenApiResponse(description="Schedule not found"), + 500: OpenApiResponse(description="Failed to send schedule"), + }, + ) + @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) + def send(self, request: Request, pk=None) -> Response: + """Trigger this schedule now""" + schedule: Schedule = self.get_object() + schedule.send() + return Response({}) diff --git 
a/authentik/tasks/schedules/apps.py b/authentik/tasks/schedules/apps.py new file mode 100644 index 0000000000..4d3dddf871 --- /dev/null +++ b/authentik/tasks/schedules/apps.py @@ -0,0 +1,51 @@ +from authentik.blueprints.apps import ManagedAppConfig +from authentik.lib.utils.reflection import get_apps +from authentik.tasks.schedules.common import ScheduleSpec + + +class AuthentikTasksSchedulesConfig(ManagedAppConfig): + name = "authentik.tasks.schedules" + label = "authentik_tasks_schedules" + verbose_name = "authentik Tasks Schedules" + default = True + + @property + def tenant_schedule_specs(self) -> list[ScheduleSpec]: + from authentik.tasks.schedules.models import ScheduledModel + + schedules = [] + for Model in ScheduledModel.models(): + for obj in Model.objects.all(): + for spec in obj.schedule_specs: + spec.rel_obj = obj + spec.identifier = obj.pk + schedules.append(spec) + return schedules + + def _reconcile_schedules(self, specs: list[ScheduleSpec]): + from django.db import transaction + + from authentik.tasks.schedules.models import Schedule + + schedules_to_send = [] + with transaction.atomic(): + pks_to_keep = [] + for spec in specs: + schedule = spec.update_or_create() + pks_to_keep.append(schedule.pk) + if spec.send_on_startup: + schedules_to_send.append(schedule) + Schedule.objects.exclude(pk__in=pks_to_keep).delete() + for schedule in schedules_to_send: + schedule.send() + + @ManagedAppConfig.reconcile_tenant + def reconcile_tenant_schedules(self): + from authentik.tenants.utils import get_current_tenant, get_public_schema_name + + schedule_specs = [] + for app in get_apps(): + schedule_specs.extend(app.tenant_schedule_specs) + if get_current_tenant().schema_name == get_public_schema_name(): + schedule_specs.extend(app.global_schedule_specs) + self._reconcile_schedules(schedule_specs) diff --git a/authentik/tasks/schedules/common.py b/authentik/tasks/schedules/common.py new file mode 100644 index 0000000000..9d93c9787c --- /dev/null +++ 
b/authentik/tasks/schedules/common.py @@ -0,0 +1,66 @@ +import pickle # nosec +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from dramatiq.actor import Actor + +if TYPE_CHECKING: + from authentik.tasks.schedules.models import Schedule + + +@dataclass +class ScheduleSpec: + actor: Actor + crontab: str + paused: bool = False + identifier: str | None = None + uid: str | None = None + + args: Iterable[Any] = field(default_factory=tuple) + kwargs: dict[str, Any] = field(default_factory=dict) + options: dict[str, Any] = field(default_factory=dict) + + rel_obj: Any | None = None + + send_on_save: bool = False + + send_on_startup: bool = False + + def get_args(self) -> bytes: + return pickle.dumps(self.args) + + def get_kwargs(self) -> bytes: + return pickle.dumps(self.kwargs) + + def get_options(self) -> bytes: + return pickle.dumps(self.options) + + def update_or_create(self) -> "Schedule": + from authentik.tasks.schedules.models import Schedule + + query = { + "actor_name": self.actor.actor_name, + "identifier": self.identifier, + } + defaults = { + **query, + "_uid": self.uid, + "paused": self.paused, + "args": self.get_args(), + "kwargs": self.get_kwargs(), + "options": self.get_options(), + } + create_defaults = { + **defaults, + "crontab": self.crontab, + "rel_obj": self.rel_obj, + } + + schedule, _ = Schedule.objects.update_or_create( + **query, + defaults=defaults, + create_defaults=create_defaults, + ) + + return schedule diff --git a/authentik/tasks/schedules/migrations/0001_initial.py b/authentik/tasks/schedules/migrations/0001_initial.py new file mode 100644 index 0000000000..ff75e50dd3 --- /dev/null +++ b/authentik/tasks/schedules/migrations/0001_initial.py @@ -0,0 +1,97 @@ +# Generated by Django 5.1.11 on 2025-07-07 16:01 + +import django.db.models.deletion +import django_dramatiq_postgres.models +import pgtrigger.compiler +import pgtrigger.migrations +import uuid +from django.db 
import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.CreateModel( + name="Schedule", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + ( + "actor_name", + models.TextField(editable=False, help_text="Dramatiq actor to call"), + ), + ("args", models.BinaryField(help_text="Args to send to the actor")), + ("kwargs", models.BinaryField(help_text="Kwargs to send to the actor")), + ("options", models.BinaryField(help_text="Options to send to the actor")), + ( + "crontab", + models.TextField( + help_text="When to schedule tasks", + validators=[django_dramatiq_postgres.models.validate_crontab], + ), + ), + ("paused", models.BooleanField(default=False, help_text="Pause this schedule")), + ("next_run", models.DateTimeField(auto_now_add=True)), + ( + "identifier", + models.TextField( + editable=False, + help_text="Unique schedule identifier", + null=True, + unique=True, + ), + ), + ( + "_uid", + models.TextField(blank=True, help_text="User schedule identifier", null=True), + ), + ("rel_obj_id", models.TextField(null=True)), + ( + "rel_obj_content_type", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + ], + options={ + "verbose_name": "Schedule", + "verbose_name_plural": "Schedules", + "permissions": [("send_schedule", "Manually trigger a schedule")], + "abstract": False, + "default_permissions": ("change", "view"), + "indexes": [ + models.Index( + fields=["rel_obj_content_type", "rel_obj_id"], + name="authentik_t_rel_obj_575af2_idx", + ) + ], + }, + ), + pgtrigger.migrations.AddTrigger( + model_name="schedule", + trigger=pgtrigger.compiler.Trigger( + name="set_next_run_on_paused", + sql=pgtrigger.compiler.UpsertTriggerSql( + condition='WHEN (NEW."paused" AND NOT OLD."paused")', + 
func="\n NEW.next_run = to_timestamp(0);\n RETURN NEW;\n ", + hash="7fe580a86de70723522cfcbac712785984000f92", + operation="UPDATE", + pgid="pgtrigger_set_next_run_on_paused_95c6d", + table="authentik_tasks_schedules_schedule", + when="BEFORE", + ), + ), + ), + ] diff --git a/authentik/tasks/schedules/migrations/__init__.py b/authentik/tasks/schedules/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/schedules/models.py b/authentik/tasks/schedules/models.py new file mode 100644 index 0000000000..b380d74e35 --- /dev/null +++ b/authentik/tasks/schedules/models.py @@ -0,0 +1,73 @@ +from django.apps import apps +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.models import ScheduleBase + +from authentik.lib.models import SerializerModel +from authentik.tasks.models import TasksModel +from authentik.tasks.schedules.common import ScheduleSpec + + +class Schedule(TasksModel, SerializerModel, ScheduleBase): + identifier = models.TextField( + unique=True, editable=False, null=True, help_text=_("Unique schedule identifier") + ) + _uid = models.TextField(blank=True, null=True, help_text=_("User schedule identifier")) + + rel_obj_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True) + rel_obj_id = models.TextField(null=True) + rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id") + + class Meta(ScheduleBase.Meta): + default_permissions = ( + "change", + "view", + ) + permissions = [ + ("send_schedule", _("Manually trigger a schedule")), + ] + indexes = (models.Index(fields=("rel_obj_content_type", "rel_obj_id")),) + + def __str__(self): + return f"Schedule {self.actor_name}:{self.uid}" + + @property + def uid(self) -> str: + uid = str(self.actor_name) + if self._uid: + uid += 
f":{self._uid}" + return uid + + @property + def serializer(self): + from authentik.tasks.schedules.api import ScheduleSerializer + + return ScheduleSerializer + + +class ScheduledModel(TasksModel, models.Model): + schedules = GenericRelation( + Schedule, content_type_field="rel_obj_content_type", object_id_field="rel_obj_id" + ) + + class Meta: + abstract = True + + @classmethod + def models(cls) -> list[models.Model]: + def is_scheduled_model(klass) -> bool: + if ScheduledModel in klass.__bases__: + return True + return any(is_scheduled_model(klass) for klass in klass.__bases__) + + return [ + model + for model in apps.get_models() + if is_scheduled_model(model) and not model.__subclasses__() + ] + + @property + def schedule_specs(self) -> list[ScheduleSpec]: + raise NotImplementedError diff --git a/authentik/tasks/schedules/scheduler.py b/authentik/tasks/schedules/scheduler.py new file mode 100644 index 0000000000..ec4b486d12 --- /dev/null +++ b/authentik/tasks/schedules/scheduler.py @@ -0,0 +1,26 @@ +import pglock +from django_dramatiq_postgres.scheduler import Scheduler as SchedulerBase +from structlog.stdlib import get_logger + +from authentik.tenants.models import Tenant + +LOGGER = get_logger() + + +class Scheduler(SchedulerBase): + def _lock(self, tenant: Tenant) -> pglock.advisory: + return pglock.advisory( + lock_id=f"authentik.scheduler/{tenant.schema_name}", + side_effect=pglock.Return, + timeout=0, + ) + + def run(self): + for tenant in Tenant.objects.filter(ready=True): + with tenant: + with self._lock(tenant) as lock_acquired: + if not lock_acquired: + self.logger.debug("Could not acquire lock, skipping scheduling") + return + count = self._run() + self.logger.info(f"Sent {count} scheduled tasks") diff --git a/authentik/tasks/schedules/signals.py b/authentik/tasks/schedules/signals.py new file mode 100644 index 0000000000..87b3810306 --- /dev/null +++ b/authentik/tasks/schedules/signals.py @@ -0,0 +1,18 @@ +from django.conf import settings +from 
django.db.models.signals import post_save +from django.dispatch import receiver + +from authentik.tasks.schedules.models import ScheduledModel + + +@receiver(post_save) +def post_save_scheduled_model(sender, instance, **_): + if not isinstance(instance, ScheduledModel): + return + if settings.TEST: + return + for spec in instance.schedule_specs: + spec.rel_obj = instance + schedule = spec.update_or_create() + if spec.send_on_save: + schedule.send() diff --git a/authentik/tasks/schedules/urls.py b/authentik/tasks/schedules/urls.py new file mode 100644 index 0000000000..5772b32d9b --- /dev/null +++ b/authentik/tasks/schedules/urls.py @@ -0,0 +1,5 @@ +from authentik.tasks.schedules.api import ScheduleViewSet + +api_urlpatterns = [ + ("tasks/schedules", ScheduleViewSet), +] diff --git a/authentik/tasks/setup.py b/authentik/tasks/setup.py new file mode 100644 index 0000000000..919575726e --- /dev/null +++ b/authentik/tasks/setup.py @@ -0,0 +1,14 @@ +from authentik.root.setup import setup + +setup() + +import django # noqa: E402 + +django.setup() + +from authentik.root.signals import post_startup, pre_startup, startup # noqa: E402 + +_startup_sender = type("WorkerStartup", (object,), {}) +pre_startup.send(sender=_startup_sender) +startup.send(sender=_startup_sender) +post_startup.send(sender=_startup_sender) diff --git a/authentik/tasks/signals.py b/authentik/tasks/signals.py new file mode 100644 index 0000000000..226a6e447b --- /dev/null +++ b/authentik/tasks/signals.py @@ -0,0 +1,45 @@ +"""admin signals""" + +import pglock +from django.dispatch import receiver +from django.utils.timezone import now, timedelta +from packaging.version import parse +from prometheus_client import Gauge + +from authentik import get_full_version +from authentik.root.monitoring import monitoring_set +from authentik.tasks.models import WorkerStatus + +OLD_GAUGE_WORKERS = Gauge( + "authentik_admin_workers", + "Currently connected workers, their versions and if they are the same version as 
authentik", + ["version", "version_matched"], +) +GAUGE_WORKERS = Gauge( + "authentik_tasks_workers", + "Currently connected workers, their versions and if they are the same version as authentik", + ["version", "version_matched"], +) + + +_version = parse(get_full_version()) + + +@receiver(monitoring_set) +def monitoring_set_workers(sender, **kwargs): + """Set worker gauge""" + worker_version_count = {} + for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(minutes=2)): + lock_id = f"goauthentik.io/worker/status/{status.pk}" + with pglock.advisory(lock_id, timeout=0, side_effect=pglock.Return) as acquired: + # The worker doesn't hold the lock, it isn't running + if acquired: + continue + version_matching = parse(status.version) == _version + worker_version_count.setdefault( + status.version, {"count": 0, "matching": version_matching} + ) + worker_version_count[status.version]["count"] += 1 + for version, stats in worker_version_count.items(): + OLD_GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"]) + GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"]) diff --git a/authentik/tasks/tasks.py b/authentik/tasks/tasks.py new file mode 100644 index 0000000000..59ef5d5ce3 --- /dev/null +++ b/authentik/tasks/tasks.py @@ -0,0 +1,10 @@ +from django.utils.timezone import now, timedelta +from django.utils.translation import gettext_lazy as _ +from dramatiq import actor + +from authentik.tasks.models import WorkerStatus + + +@actor(description=_("Remove old worker statuses.")) +def clean_worker_statuses(): + WorkerStatus.objects.filter(last_seen__lt=now() - timedelta(days=1)).delete() diff --git a/authentik/tasks/test.py b/authentik/tasks/test.py new file mode 100644 index 0000000000..168b44ad75 --- /dev/null +++ b/authentik/tasks/test.py @@ -0,0 +1,82 @@ +from queue import PriorityQueue + +import dramatiq +from django.utils.module_loading import import_string +from django_dramatiq_postgres.conf import Conf +from 
dramatiq.broker import Broker, MessageProxy, get_broker +from dramatiq.middleware.retries import Retries +from dramatiq.results.middleware import Results +from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread + +from authentik.tasks.broker import PostgresBroker +from authentik.tasks.middleware import MetricsMiddleware + + +class TestWorker(Worker): + def __init__(self, queue_name: str, broker: Broker): + super().__init__(broker=broker) + self.work_queue = PriorityQueue() + self.consumers = { + queue_name: _ConsumerThread( + broker=self.broker, + queue_name=queue_name, + prefetch=2, + work_queue=self.work_queue, + worker_timeout=1, + ), + } + self.consumers[queue_name].consumer = self.broker.consume( + queue_name=queue_name, + prefetch=2, + timeout=1, + ) + self._worker = _WorkerThread( + broker=self.broker, + consumers=self.consumers, + work_queue=self.work_queue, + worker_timeout=1, + ) + + self.broker.emit_before("worker_boot", self) + self.broker.emit_after("worker_boot", self) + + def process_message(self, message: MessageProxy): + self.work_queue.put(message) + self.consumers[message.queue_name].consumer.in_processing.add(message.message_id) + self._worker.process_message(message) + + +class TestBroker(PostgresBroker): + def enqueue(self, *args, **kwargs): + message = super().enqueue(*args, **kwargs) + worker = TestWorker(message.queue_name, broker=self) + worker.process_message(MessageProxy(message)) + return message + + +def use_test_broker(): + old_broker = get_broker() + + broker = TestBroker() + + for actor_name in old_broker.get_declared_actors(): + actor = old_broker.get_actor(actor_name) + actor.broker = broker + actor.broker.declare_actor(actor) + + for middleware_class, middleware_kwargs in Conf().middlewares: + middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)( + **middleware_kwargs, + ) + if isinstance(middleware, MetricsMiddleware): + continue + if isinstance(middleware, Retries): + 
middleware.max_retries = 0 + if isinstance(middleware, Results): + middleware.backend = import_string(Conf().result_backend)( + *Conf().result_backend_args, + **Conf().result_backend_kwargs, + ) + broker.add_middleware(middleware) + + dramatiq.set_broker(broker) diff --git a/authentik/tasks/tests/__init__.py b/authentik/tasks/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/tests/test_actors.py b/authentik/tasks/tests/test_actors.py new file mode 100644 index 0000000000..8e2f6cb2c3 --- /dev/null +++ b/authentik/tasks/tests/test_actors.py @@ -0,0 +1,10 @@ +from django.test import TestCase +from dramatiq.broker import get_broker + + +class TestActors(TestCase): + def test_all_actors_have_description(self): + broker = get_broker() + for actor_name in broker.get_declared_actors(): + actor = broker.get_actor(actor_name) + self.assertIn("description", actor.options) diff --git a/authentik/tasks/tests/test_api.py b/authentik/tasks/tests/test_api.py new file mode 100644 index 0000000000..2f2b89c0ef --- /dev/null +++ b/authentik/tasks/tests/test_api.py @@ -0,0 +1,26 @@ +from json import loads + +from django.test import TestCase +from django.urls import reverse + +from authentik.core.models import Group, User +from authentik.lib.generators import generate_id + + +class TestAdminAPI(TestCase): + """test admin api""" + + def setUp(self) -> None: + super().setUp() + self.user = User.objects.create(username=generate_id()) + self.group = Group.objects.create(name=generate_id(), is_superuser=True) + self.group.users.add(self.user) + self.group.save() + self.client.force_login(self.user) + + def test_workers(self): + """Test Workers API""" + response = self.client.get(reverse("authentik_api:tasks_workers")) + self.assertEqual(response.status_code, 200) + body = loads(response.content) + self.assertEqual(len(body), 0) diff --git a/authentik/tasks/urls.py b/authentik/tasks/urls.py new file mode 100644 index 0000000000..d48ef1597d --- 
/dev/null +++ b/authentik/tasks/urls.py @@ -0,0 +1,9 @@ +from django.urls import path + +from authentik.tasks.api.tasks import TaskViewSet +from authentik.tasks.api.workers import WorkerView + +api_urlpatterns = [ + ("tasks/tasks", TaskViewSet), + path("tasks/workers", WorkerView.as_view(), name="tasks_workers"), +] diff --git a/authentik/tasks/worker.py b/authentik/tasks/worker.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tenants/migrations/0001_initial.py b/authentik/tenants/migrations/0001_initial.py index eda3e7c917..e4eac2f28b 100644 --- a/authentik/tenants/migrations/0001_initial.py +++ b/authentik/tenants/migrations/0001_initial.py @@ -5,6 +5,7 @@ import uuid import django.db.models.deletion import django_tenants.postgresql_backend.base from django.db import migrations, models +from django_tenants.utils import get_tenant_base_schema import authentik.lib.utils.time import authentik.tenants.models @@ -144,7 +145,7 @@ class Migration(migrations.Migration): ), migrations.RunPython(code=create_default_tenant, reverse_code=migrations.RunPython.noop), migrations.RunSQL( - sql="CREATE SCHEMA IF NOT EXISTS template;", - reverse_sql="DROP SCHEMA IF EXISTS template;", + sql=f"CREATE SCHEMA IF NOT EXISTS {get_tenant_base_schema()};", + reverse_sql=f"DROP SCHEMA IF EXISTS {get_tenant_base_schema()};", ), ] diff --git a/authentik/tenants/models.py b/authentik/tenants/models.py index bc3f9a464c..629c917d25 100644 --- a/authentik/tenants/models.py +++ b/authentik/tenants/models.py @@ -4,6 +4,7 @@ import re from uuid import uuid4 from django.apps import apps +from django.conf import settings from django.core.exceptions import ValidationError from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models @@ -11,6 +12,7 @@ from django.db.utils import IntegrityError from django.dispatch import receiver from django.utils.translation import gettext_lazy as _ from django_tenants.models import DomainMixin, 
TenantMixin, post_schema_sync +from django_tenants.utils import get_tenant_base_schema from rest_framework.serializers import Serializer from structlog.stdlib import get_logger @@ -113,8 +115,8 @@ class Tenant(TenantMixin, SerializerModel): ) def save(self, *args, **kwargs): - if self.schema_name == "template": - raise IntegrityError("Cannot create schema named template") + if self.schema_name == get_tenant_base_schema() and not settings.TEST: + raise IntegrityError(f"Cannot create schema named {self.schema_name}") super().save(*args, **kwargs) @property diff --git a/authentik/tenants/scheduler.py b/authentik/tenants/scheduler.py deleted file mode 100644 index 753831ae0b..0000000000 --- a/authentik/tenants/scheduler.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Tenant-aware Celery beat scheduler""" - -from tenant_schemas_celery.scheduler import ( - TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler, -) -from tenant_schemas_celery.scheduler import TenantAwareScheduleEntry - - -class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): - """Tenant-aware Celery beat scheduler""" - - @classmethod - def get_queryset(cls): - return super().get_queryset().filter(ready=True) - - def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None): - # https://github.com/maciej-gol/tenant-schemas-celery/blob/master/tenant_schemas_celery/scheduler.py#L85 - # When (as by default) no tenant schemas are set, the public schema is excluded - # so we need to explicitly include it here, otherwise the task is not executed - if entry.tenant_schemas is None: - entry.tenant_schemas = self.get_queryset().values_list("schema_name", flat=True) - return super().apply_entry(entry, producer) diff --git a/authentik/tenants/tests/test_settings.py b/authentik/tenants/tests/test_settings.py deleted file mode 100644 index 63eb09cbff..0000000000 --- a/authentik/tenants/tests/test_settings.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Test Settings API""" - -from django.urls 
import reverse - -from authentik.core.tests.utils import create_test_admin_user -from authentik.lib.generators import generate_id -from authentik.tenants.models import Domain, Tenant -from authentik.tenants.tests.utils import TenantAPITestCase - -TENANTS_API_KEY = generate_id() -HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"} - - -class TestSettingsAPI(TenantAPITestCase): - """Test settings API""" - - def setUp(self): - super().setUp() - self.tenant_1 = Tenant.objects.create( - name=generate_id(), schema_name="t_" + generate_id().lower() - ) - Domain.objects.create(tenant=self.tenant_1, domain="tenant1.testserver") - with self.tenant_1: - self.admin_1 = create_test_admin_user() - self.tenant_2 = Tenant.objects.create( - name=generate_id(), schema_name="t_" + generate_id().lower() - ) - Domain.objects.create(tenant=self.tenant_2, domain="tenant2.testserver") - with self.tenant_2: - self.admin_2 = create_test_admin_user() - - def test_settings(self): - """Test settings API""" - # First edit settings to different values in two different tenants - # We need those context managers here because the test client doesn't put itself - # in the tenant context as a real request would. 
- with self.tenant_1: - self.client.force_login(self.admin_1) - response = self.client.patch( - reverse("authentik_api:tenant_settings"), - data={ - "avatars": "tenant_1_mode", - }, - HTTP_HOST="tenant1.testserver", - ) - self.assertEqual(response.status_code, 200) - with self.tenant_1: - self.client.logout() - - with self.tenant_2: - self.client.force_login(self.admin_2) - response = self.client.patch( - reverse("authentik_api:tenant_settings"), - data={ - "avatars": "tenant_2_mode", - }, - HTTP_HOST="tenant2.testserver", - ) - self.assertEqual(response.status_code, 200) - with self.tenant_2: - self.client.logout() - - # Assert that the settings have changed and are different - self.tenant_1.refresh_from_db() - self.tenant_2.refresh_from_db() - self.assertEqual(self.tenant_1.avatars, "tenant_1_mode") - self.assertEqual(self.tenant_2.avatars, "tenant_2_mode") diff --git a/authentik/tenants/tests/utils.py b/authentik/tenants/tests/utils.py index 5f6a568ed0..074eef4fd8 100644 --- a/authentik/tenants/tests/utils.py +++ b/authentik/tenants/tests/utils.py @@ -1,7 +1,10 @@ from django.core.management import call_command from django.db import connection, connections +from django_tenants.utils import get_public_schema_name, get_tenant_base_schema, schema_context from rest_framework.test import APITransactionTestCase +from authentik.tenants.models import Tenant + class TenantAPITestCase(APITransactionTestCase): # Overridden to also remove additional schemas we may have created @@ -17,7 +20,12 @@ class TenantAPITestCase(APITransactionTestCase): super()._fixture_teardown() def setUp(self): - call_command("migrate_schemas", schema="template", tenant=True) + with schema_context(get_public_schema_name()): + Tenant.objects.update_or_create( + defaults={"name": "Template", "ready": False}, + schema_name=get_tenant_base_schema(), + ) + call_command("migrate_schemas", schema=get_tenant_base_schema(), tenant=True) def assertSchemaExists(self, schema_name): with connection.cursor() as 
cursor: @@ -28,7 +36,8 @@ class TenantAPITestCase(APITransactionTestCase): self.assertEqual(cursor.rowcount, 1) cursor.execute( - "SELECT * FROM information_schema.tables WHERE table_schema = 'template'" + "SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s", + {"schema_name": get_tenant_base_schema()}, ) expected_tables = cursor.rowcount cursor.execute( diff --git a/blueprints/schema.json b/blueprints/schema.json index 1a6b8e10b8..1c364171b0 100644 --- a/blueprints/schema.json +++ b/blueprints/schema.json @@ -4216,6 +4216,46 @@ } } }, + { + "type": "object", + "required": [ + "model", + "identifiers" + ], + "properties": { + "model": { + "const": "authentik_tasks_schedules.schedule" + }, + "id": { + "type": "string" + }, + "state": { + "type": "string", + "enum": [ + "absent", + "created", + "must_created", + "present" + ], + "default": "present" + }, + "conditions": { + "type": "array", + "items": { + "type": "boolean" + } + }, + "permissions": { + "$ref": "#/$defs/model_authentik_tasks_schedules.schedule_permissions" + }, + "attrs": { + "$ref": "#/$defs/model_authentik_tasks_schedules.schedule" + }, + "identifiers": { + "$ref": "#/$defs/model_authentik_tasks_schedules.schedule" + } + } + }, { "type": "object", "required": [ @@ -4933,13 +4973,11 @@ "authentik_events.delete_notificationrule", "authentik_events.delete_notificationtransport", "authentik_events.delete_notificationwebhookmapping", - "authentik_events.run_task", "authentik_events.view_event", "authentik_events.view_notification", "authentik_events.view_notificationrule", "authentik_events.view_notificationtransport", "authentik_events.view_notificationwebhookmapping", - "authentik_events.view_systemtask", "authentik_flows.add_flow", "authentik_flows.add_flowstagebinding", "authentik_flows.add_flowtoken", @@ -5414,6 +5452,11 @@ "authentik_stages_user_write.change_userwritestage", "authentik_stages_user_write.delete_userwritestage", 
"authentik_stages_user_write.view_userwritestage", + "authentik_tasks.retry_task", + "authentik_tasks.view_task", + "authentik_tasks_schedules.change_schedule", + "authentik_tasks_schedules.send_schedule", + "authentik_tasks_schedules.view_schedule", "authentik_tenants.add_domain", "authentik_tenants.add_tenant", "authentik_tenants.change_domain", @@ -7288,9 +7331,11 @@ "enum": [ null, "authentik.tenants", + "authentik.tasks", "authentik.admin", "authentik.api", "authentik.crypto", + "authentik.events", "authentik.flows", "authentik.outposts", "authentik.policies.dummy", @@ -7338,6 +7383,7 @@ "authentik.stages.user_login", "authentik.stages.user_logout", "authentik.stages.user_write", + "authentik.tasks.schedules", "authentik.brands", "authentik.blueprints", "authentik.core", @@ -7350,8 +7396,7 @@ "authentik.enterprise.search", "authentik.enterprise.stages.authenticator_endpoint_gdtc", "authentik.enterprise.stages.mtls", - "authentik.enterprise.stages.source", - "authentik.events" + "authentik.enterprise.stages.source" ], "title": "App", "description": "Match events created by selected application. When left empty, all applications are matched." 
@@ -7365,6 +7410,11 @@ null, "authentik_tenants.domain", "authentik_crypto.certificatekeypair", + "authentik_events.event", + "authentik_events.notificationtransport", + "authentik_events.notification", + "authentik_events.notificationrule", + "authentik_events.notificationwebhookmapping", "authentik_flows.flow", "authentik_flows.flowstagebinding", "authentik_outposts.dockerserviceconnection", @@ -7445,6 +7495,7 @@ "authentik_stages_user_login.userloginstage", "authentik_stages_user_logout.userlogoutstage", "authentik_stages_user_write.userwritestage", + "authentik_tasks_schedules.schedule", "authentik_brands.brand", "authentik_blueprints.blueprintinstance", "authentik_core.group", @@ -7461,12 +7512,7 @@ "authentik_providers_ssf.ssfprovider", "authentik_stages_authenticator_endpoint_gdtc.authenticatorendpointgdtcstage", "authentik_stages_mtls.mutualtlsstage", - "authentik_stages_source.sourcestage", - "authentik_events.event", - "authentik_events.notificationtransport", - "authentik_events.notification", - "authentik_events.notificationrule", - "authentik_events.notificationwebhookmapping" + "authentik_stages_source.sourcestage" ], "title": "Model", "description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched." 
@@ -9580,13 +9626,11 @@ "authentik_events.delete_notificationrule", "authentik_events.delete_notificationtransport", "authentik_events.delete_notificationwebhookmapping", - "authentik_events.run_task", "authentik_events.view_event", "authentik_events.view_notification", "authentik_events.view_notificationrule", "authentik_events.view_notificationtransport", "authentik_events.view_notificationwebhookmapping", - "authentik_events.view_systemtask", "authentik_flows.add_flow", "authentik_flows.add_flowstagebinding", "authentik_flows.add_flowtoken", @@ -10061,6 +10105,11 @@ "authentik_stages_user_write.change_userwritestage", "authentik_stages_user_write.delete_userwritestage", "authentik_stages_user_write.view_userwritestage", + "authentik_tasks.retry_task", + "authentik_tasks.view_task", + "authentik_tasks_schedules.change_schedule", + "authentik_tasks_schedules.send_schedule", + "authentik_tasks_schedules.view_schedule", "authentik_tenants.add_domain", "authentik_tenants.add_tenant", "authentik_tenants.change_domain", @@ -15858,6 +15907,56 @@ } } }, + "model_authentik_tasks_schedules.schedule": { + "type": "object", + "properties": { + "rel_obj_id": { + "type": [ + "string", + "null" + ], + "minLength": 1, + "title": "Rel obj id" + }, + "crontab": { + "type": "string", + "minLength": 1, + "title": "Crontab", + "description": "When to schedule tasks" + }, + "paused": { + "type": "boolean", + "title": "Paused", + "description": "Pause this schedule" + } + }, + "required": [] + }, + "model_authentik_tasks_schedules.schedule_permissions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "permission" + ], + "properties": { + "permission": { + "type": "string", + "enum": [ + "change_schedule", + "send_schedule", + "view_schedule" + ] + }, + "user": { + "type": "integer" + }, + "role": { + "type": "string" + } + } + } + }, "model_authentik_tenants.domain": { "type": "object", "properties": { diff --git a/cmd/server/healthcheck.go 
b/cmd/server/healthcheck.go index 0f0ccec029..ef02c46685 100644 --- a/cmd/server/healthcheck.go +++ b/cmd/server/healthcheck.go @@ -5,8 +5,9 @@ import ( "net/http" "os" "path" + "strconv" "strings" - "time" + "syscall" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -14,9 +15,7 @@ import ( "goauthentik.io/internal/utils/web" ) -var workerHeartbeat = path.Join(os.TempDir(), "authentik-worker") - -const workerThreshold = 30 +var workerPidFile = path.Join(os.TempDir(), "authentik-worker.pid") var healthcheckCmd = &cobra.Command{ Use: "healthcheck", @@ -61,16 +60,79 @@ func checkServer() int { return 0 } +func splitHostPort(address string) (host, port string) { + lastColon := strings.LastIndex(address, ":") + if lastColon == -1 { + return address, "" + } + + host = address[:lastColon] + port = address[lastColon+1:] + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return host, port +} + func checkWorker() int { - stat, err := os.Stat(workerHeartbeat) + pidB, err := os.ReadFile(workerPidFile) if err != nil { - log.WithError(err).Warning("failed to check worker heartbeat file") + log.WithError(err).Warning("failed to check worker PID file") return 1 } - delta := time.Since(stat.ModTime()).Seconds() - if delta > workerThreshold { - log.WithField("threshold", workerThreshold).WithField("delta", delta).Warning("Worker hasn't updated heartbeat in threshold") + pidS := strings.TrimSpace(string(pidB[:])) + pid, err := strconv.Atoi(pidS) + if err != nil { + log.WithError(err).Warning("failed to find worker process PID") return 1 } + process, err := os.FindProcess(pid) + if err != nil { + log.WithError(err).Warning("failed to find worker process") + return 1 + } + err = process.Signal(syscall.Signal(0)) + if err != nil { + log.WithError(err).Warning("failed to signal worker process") + return 1 + } + h := &http.Client{ + Transport: web.NewUserAgentTransport("goauthentik.io/healthcheck", 
http.DefaultTransport), + } + + host, port := splitHostPort(config.Get().Listen.HTTP) + + if host == "0.0.0.0" || host == "::" { + url := fmt.Sprintf("http://%s:%s/-/health/ready/", "::1", port) + _, err := h.Head(url) + if err != nil { + log.WithError(err).WithField("url", url).Warning("failed to send healthcheck request") + url := fmt.Sprintf("http://%s:%s/-/health/ready/", "127.0.0.1", port) + res, err := h.Head(url) + if err != nil { + log.WithError(err).WithField("url", url).Warning("failed to send healthcheck request") + return 1 + } + if res.StatusCode >= 400 { + log.WithField("status", res.StatusCode).Warning("unhealthy status code") + return 1 + } + } + } else { + url := fmt.Sprintf("http://%s:%s/-/health/ready/", host, port) + res, err := h.Head(url) + if err != nil { + log.WithError(err).Warning("failed to send healthcheck request") + return 1 + } + if res.StatusCode >= 400 { + log.WithField("status", res.StatusCode).Warning("unhealthy status code") + return 1 + } + } + + log.Debug("successfully checked health") return 0 } diff --git a/lifecycle/ak b/lifecycle/ak index e4f214f3a8..a25a8d3930 100755 --- a/lifecycle/ak +++ b/lifecycle/ak @@ -74,21 +74,17 @@ fi if [[ "$1" == "server" ]]; then set_mode "server" - # If we have bootstrap credentials set, run bootstrap tasks outside of main server - # sync, so that we can sure the first start actually has working bootstrap - # credentials - if [[ ! -z "${AUTHENTIK_BOOTSTRAP_PASSWORD}" || ! 
-z "${AUTHENTIK_BOOTSTRAP_TOKEN}" ]]; then - python -m manage bootstrap_tasks - fi run_authentik elif [[ "$1" == "worker" ]]; then set_mode "worker" shift - check_if_root "python -m manage worker $@" -elif [[ "$1" == "worker-status" ]]; then - wait_for_db - celery -A authentik.root.celery flower \ - --port=9000 + # If we have bootstrap credentials set, run bootstrap tasks outside of main server + # sync, so that we can sure the first start actually has working bootstrap + # credentials + if [[ -n "${AUTHENTIK_BOOTSTRAP_PASSWORD}" || -n "${AUTHENTIK_BOOTSTRAP_TOKEN}" ]]; then + python -m manage apply_blueprint system/bootstrap.yaml || true + fi + check_if_root "python -m manage worker --pid-file ${TMPDIR}/authentik-worker.pid $@" elif [[ "$1" == "bash" ]]; then /bin/bash elif [[ "$1" == "test-all" ]]; then diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index 58b6643807..4381684d61 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -2,7 +2,6 @@ import os from hashlib import sha512 -from os import makedirs from pathlib import Path from tempfile import gettempdir from typing import TYPE_CHECKING @@ -33,11 +32,11 @@ wait_for_db() _tmp = Path(gettempdir()) worker_class = "lifecycle.worker.DjangoUvicornWorker" -worker_tmp_dir = str(_tmp.joinpath("authentik_worker_tmp")) +worker_tmp_dir = str(_tmp.joinpath("authentik_gunicorn_tmp")) prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp")) -makedirs(worker_tmp_dir, exist_ok=True) -makedirs(prometheus_tmp_dir, exist_ok=True) +os.makedirs(worker_tmp_dir, exist_ok=True) +os.makedirs(prometheus_tmp_dir, exist_ok=True) bind = f"unix://{str(_tmp.joinpath('authentik-core.sock'))}" diff --git a/manage.py b/manage.py index 1cfca76534..060df5e2c7 100755 --- a/manage.py +++ b/manage.py @@ -17,9 +17,7 @@ if __name__ == "__main__": if ( len(sys.argv) > 1 # Explicitly only run migrate for server and worker - # `bootstrap_tasks` is a special case as that command might be triggered 
by the `ak` - # script to pre-run certain tasks for an automated install - and sys.argv[1] in ["dev_server", "worker", "bootstrap_tasks"] + and sys.argv[1] in ["dev_server", "worker"] # and don't run if this is the child process of a dev_server and os.environ.get(DJANGO_AUTORELOAD_ENV, None) is None ): diff --git a/packages/django-dramatiq-postgres/README.md b/packages/django-dramatiq-postgres/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/__init__.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py new file mode 100644 index 0000000000..89873dfc14 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py @@ -0,0 +1,45 @@ +import dramatiq +from django.apps import AppConfig +from django.core.exceptions import ImproperlyConfigured +from django.utils.module_loading import import_string +from dramatiq.results.middleware import Results + +from django_dramatiq_postgres.conf import Conf + + +class DjangoDramatiqPostgres(AppConfig): + name = "django_dramatiq_postgres" + verbose_name = "Django Dramatiq postgres" + + def ready(self): + old_broker = dramatiq.get_broker() + + if len(old_broker.actors) != 0: + raise ImproperlyConfigured( + "Actors were previously registered. " + "Make sure your actors are not imported too early." 
+ ) + + encoder: dramatiq.encoder.Encoder = import_string(Conf().encoder_class)() + dramatiq.set_encoder(encoder) + + broker: dramatiq.broker.Broker = import_string(Conf().broker_class)( + *Conf().broker_args, + **Conf().broker_kwargs, + middleware=[], + ) + + for middleware_class, middleware_kwargs in Conf().middlewares: + middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)( + **middleware_kwargs, + ) + if isinstance(middleware, Results): + middleware.backend = import_string(Conf().result_backend)( + *Conf().result_backend_args, + **Conf().result_backend_kwargs, + ) + broker.add_middleware(middleware) + + dramatiq.set_broker(broker) + + return super().ready() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py new file mode 100644 index 0000000000..50cd8576e9 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py @@ -0,0 +1,454 @@ +import functools +import logging +import time +from collections.abc import Iterable +from queue import Empty, Queue +from typing import Any + +import tenacity +from django.core.exceptions import ImproperlyConfigured +from django.db import ( + DEFAULT_DB_ALIAS, + DatabaseError, + InterfaceError, + OperationalError, + connections, + transaction, +) +from django.db.backends.postgresql.base import DatabaseWrapper +from django.db.models import QuerySet +from django.utils import timezone +from django.utils.functional import cached_property +from django.utils.module_loading import import_string +from dramatiq.broker import Broker, Consumer, MessageProxy +from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name +from dramatiq.errors import ConnectionError, QueueJoinTimeout +from dramatiq.logging import get_logger +from dramatiq.message import Message +from dramatiq.middleware import ( + Middleware, +) +from pglock.core import _cast_lock_id +from psycopg 
import Notify, sql +from psycopg.errors import AdminShutdown + +from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState + +logger = get_logger(__name__) + + +def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: + return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}" + + +def raise_connection_error(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except OperationalError as exc: + raise ConnectionError(str(exc)) from exc + + return wrapper + + +class PostgresBroker(Broker): + def __init__( + self, + *args, + middleware: list[Middleware] | None = None, + db_alias: str = DEFAULT_DB_ALIAS, + **kwargs, + ): + super().__init__(*args, middleware=[], **kwargs) + self.logger = get_logger(__name__, type(self)) + + self.queues = set() + + self.db_alias = db_alias + self.middleware = [] + if middleware: + raise ImproperlyConfigured( + "Middlewares should be set in django settings, not passed directly to the broker." 
+ ) + + @property + def connection(self) -> DatabaseWrapper: + return connections[self.db_alias] + + @property + def consumer_class(self) -> "type[_PostgresConsumer]": + return _PostgresConsumer + + @cached_property + def model(self) -> type[TaskBase]: + return import_string(Conf().task_model) + + @property + def query_set(self) -> QuerySet: + return self.model.objects.using(self.db_alias).defer("message", "result") + + def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer: + self.declare_queue(queue_name) + return self.consumer_class( + broker=self, + db_alias=self.db_alias, + queue_name=queue_name, + prefetch=prefetch, + timeout=timeout, + ) + + def declare_queue(self, queue_name: str): + if queue_name not in self.queues: + self.emit_before("declare_queue", queue_name) + self.queues.add(queue_name) + # Nothing to do, all queues are in the same table + self.emit_after("declare_queue", queue_name) + + delayed_name = dq_name(queue_name) + self.delay_queues.add(delayed_name) + self.emit_after("declare_delay_queue", delayed_name) + + def model_defaults(self, message: Message) -> dict[str, Any]: + return { + "queue_name": message.queue_name, + "actor_name": message.actor_name, + "state": TaskState.QUEUED, + } + + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + ( + AdminShutdown, + InterfaceError, + DatabaseError, + ConnectionError, + OperationalError, + ) + ), + reraise=True, + wait=tenacity.wait_random_exponential(multiplier=1, max=5), + stop=tenacity.stop_after_attempt(3), + before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True), + ) + def enqueue(self, message: Message, *, delay: int | None = None) -> Message: + canonical_queue_name = message.queue_name + queue_name = canonical_queue_name + if delay: + queue_name = dq_name(queue_name) + message_eta = current_millis() + delay + message = message.copy( + queue_name=queue_name, + options={ + "eta": message_eta, + }, + ) + + 
self.declare_queue(canonical_queue_name) + self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") + + message.options["model_defaults"] = self.model_defaults(message) + self.emit_before("enqueue", message, delay) + + with transaction.atomic(using=self.db_alias): + query = { + "message_id": message.message_id, + } + defaults = message.options["model_defaults"] + del message.options["model_defaults"] + defaults["message"] = message.encode() + create_defaults = { + **query, + **defaults, + } + + task, created = self.query_set.update_or_create( + **query, + defaults=defaults, + create_defaults=create_defaults, + ) + message.options["task"] = task + message.options["task_created"] = created + + self.emit_after("enqueue", message, delay) + return message + + def get_declared_queues(self) -> set[str]: + return self.queues.copy() + + def flush(self, queue_name: str): + self.query_set.filter( + queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name)) + ).delete() + + def flush_all(self): + for queue_name in self.queues: + self.flush(queue_name) + + def join( + self, + queue_name: str, + interval: int = 100, + *, + timeout: int | None = None, + ): + deadline = timeout and time.monotonic() + timeout / 1000 + while True: + if deadline and time.monotonic() >= deadline: + raise QueueJoinTimeout(queue_name) + + if self.query_set.filter( + queue_name=queue_name, + state__in=(TaskState.QUEUED, TaskState.CONSUMED), + ).exists(): + return + + time.sleep(interval / 1000) + + +class _PostgresConsumer(Consumer): + def __init__( + self, + *args, + broker: PostgresBroker, + db_alias: str, + queue_name: str, + prefetch: int, + timeout: int, + **kwargs, + ): + self.logger = get_logger(__name__, type(self)) + + self.notifies: list[Notify] = [] + self.broker = broker + self.db_alias = db_alias + self.queue_name = queue_name + self.timeout = timeout // 1000 + self.unlock_queue = Queue() + self.in_processing = set() + self.prefetch = prefetch + 
self.misses = 0 + self._listen_connection: DatabaseWrapper | None = None + self.postgres_channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE) + + # Override because dramatiq doesn't allow us setting this manually + self.timeout = Conf().worker["consumer_listen_timeout"] + + self.task_purge_interval = timezone.timedelta(seconds=Conf().task_purge_interval) + self.task_purge_last_run = timezone.now() - self.task_purge_interval + + self.scheduler = None + if Conf().schedule_model: + self.scheduler = import_string(Conf().scheduler_class)() + self.scheduler.broker = self.broker + self.scheduler_interval = timezone.timedelta(seconds=Conf().scheduler_interval) + self.scheduler_last_run = timezone.now() - self.scheduler_interval + + @property + def connection(self) -> DatabaseWrapper: + return connections[self.db_alias] + + @property + def query_set(self) -> QuerySet: + return self.broker.query_set + + @property + def listen_connection(self) -> DatabaseWrapper: + if self._listen_connection is not None and self._listen_connection.is_usable(): + return self._listen_connection + self._listen_connection = connections.create_connection(self.db_alias) + # Required for notifications + # See https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-notifications + # Should be set to True by Django by default + self._listen_connection.set_autocommit(True) + with self._listen_connection.cursor() as cursor: + cursor.execute(sql.SQL("LISTEN {}").format(sql.Identifier(self.postgres_channel))) + return self._listen_connection + + @raise_connection_error + def ack(self, message: Message): + task = message.options.pop("task", None) + self.query_set.filter( + message_id=message.message_id, + queue_name=message.queue_name, + state=TaskState.CONSUMED, + ).update( + state=TaskState.DONE, + message=message.encode(), + ) + message.options["task"] = task + self.unlock_queue.put_nowait(message.message_id) + self.in_processing.remove(message.message_id) + + 
@raise_connection_error + def nack(self, message: Message): + task = message.options.pop("task", None) + self.query_set.filter( + message_id=message.message_id, + queue_name=message.queue_name, + ).exclude( + state=TaskState.REJECTED, + ).update( + state=TaskState.REJECTED, + message=message.encode(), + ) + message.options["task"] = task + self.unlock_queue.put_nowait(message.message_id) + self.in_processing.remove(message.message_id) + + @raise_connection_error + def requeue(self, messages: Iterable[Message]): + self.query_set.filter( + message_id__in=[message.message_id for message in messages], + ).update( + state=TaskState.QUEUED, + ) + for message in messages: + self.unlock_queue.put_nowait(message.message_id) + self.in_processing.remove(message.message_id) + self._purge_locks() + + def _fetch_pending_notifies(self) -> list[Notify]: + self.logger.debug(f"Polling for lost messages in {self.queue_name}") + notifies = ( + self.query_set.filter( + state__in=(TaskState.QUEUED, TaskState.CONSUMED), + queue_name=self.queue_name, + ) + .exclude( + message_id__in=self.in_processing, + ) + .values_list("message_id", flat=True) + ) + return [Notify(pid=0, channel=self.postgres_channel, payload=item) for item in notifies] + + def _poll_for_notify(self): + with self.listen_connection.cursor() as cursor: + notifies = list(cursor.connection.notifies(timeout=self.timeout, stop_after=1)) + self.logger.debug( + f"Received {len(notifies)} postgres notifies on channel {self.postgres_channel}" + ) + self.notifies += notifies + + def _get_message_lock_id(self, message_id: str) -> int: + return _cast_lock_id( + f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}" + ) + + def _consume_one(self, message: Message) -> bool: + if message.message_id in self.in_processing: + self.logger.debug(f"Message {message.message_id} already consumed by self") + return False + + result = ( + self.query_set.filter( + message_id=message.message_id, + state__in=(TaskState.QUEUED, 
TaskState.CONSUMED), + ) + .extra( + where=["pg_try_advisory_lock(%s)"], + params=[self._get_message_lock_id(message.message_id)], + ) + .update( + state=TaskState.CONSUMED, + mtime=timezone.now(), + ) + ) + return result == 1 + + @raise_connection_error + def __next__(self) -> MessageProxy | None: + # This method is called every second + + # If we don't have a connection yet, fetch missed notifications from the table directly + if self._listen_connection is None: + # We might miss a notification between the initial query and the first time we wait for + # notifications, it doesn't matter because we re-fetch for missed messages later on. + self.notifies = self._fetch_pending_notifies() + self.logger.debug( + f"Found {len(self.notifies)} pending messages in queue {self.queue_name}" + ) + + processing = len(self.in_processing) + if processing >= self.prefetch: + # Wait and don't consume the message, other worker will be faster + self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000) + self.logger.debug( + f"Too many messages in processing: {processing}. Sleeping {backoff_ms} ms" + ) + time.sleep(backoff_ms / 1000) + return None + + if not self.notifies: + self._poll_for_notify() + + if not self.notifies: + self.notifies[:] = self._fetch_pending_notifies() + + # If we have some notifies, loop to find one to do + while self.notifies: + notify = self.notifies.pop(0) + task: TaskBase | None = ( + self.query_set.defer(None).defer("result").filter(message_id=notify.payload).first() + ) + if task is None: + continue + message = Message.decode(task.message) + message.options["task"] = task + if self._consume_one(message): + self.in_processing.add(message.message_id) + return MessageProxy(message) + else: + self.logger.debug(f"Message {message.message_id} already consumed. 
Skipping.") + + # No message to process + self._purge_locks() + self._auto_purge() + self._scheduler() + + return None + + def _purge_locks(self): + while True: + try: + message_id = self.unlock_queue.get(block=False) + except Empty: + return + self.logger.debug(f"Unlocking message {message_id}") + with self.connection.cursor() as cursor: + cursor.execute( + "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),) + ) + self.unlock_queue.task_done() + + def _auto_purge(self): + if timezone.now() - self.task_purge_last_run < self.task_purge_interval: + return + self.logger.debug("Running garbage collector") + count = self.query_set.filter( + state__in=(TaskState.DONE, TaskState.REJECTED), + mtime__lte=timezone.now() - timezone.timedelta(seconds=Conf().task_purge_interval), + result_expiry__lte=timezone.now(), + ).delete() + self.logger.info(f"Purged {count} messages in all queues") + + def _scheduler(self): + if not self.scheduler: + return + if timezone.now() - self.scheduler_last_run < self.scheduler_interval: + return + self.scheduler.run() + + @raise_connection_error + def close(self): + try: + self._purge_locks() + finally: + try: + self.connection.close() + finally: + if self._listen_connection is not None: + conn = self._listen_connection + self._listen_connection = None + conn.close() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py new file mode 100644 index 0000000000..fab7b098bb --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py @@ -0,0 +1,124 @@ +from typing import Any + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured + + +class Conf: + def __init__(self): + try: + _ = settings.DRAMATIQ + except AttributeError as exc: + raise ImproperlyConfigured("Setting DRAMATIQ not set.") from exc + if "task_model" not in self.conf: + raise 
ImproperlyConfigured("DRAMATIQ.task_model not defined") + + @property + def conf(self) -> dict[str, Any]: + return settings.DRAMATIQ + + @property + def encoder_class(self) -> str: + return self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder") + + @property + def broker_class(self) -> str: + return self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker") + + @property + def broker_args(self) -> tuple[Any]: + return self.conf.get("broker_args", ()) + + @property + def broker_kwargs(self) -> dict[str, Any]: + return self.conf.get("broker_kwargs", {}) + + @property + def middlewares(self) -> tuple[tuple[str, dict[str, Any]]]: + return self.conf.get( + "middlewares", + ( + ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), + ("dramatiq.middleware.age_limit.AgeLimit", {}), + ("dramatiq.middleware.time_limit.TimeLimit", {}), + ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), + ("dramatiq.middleware.callbacks.Callbacks", {}), + ("dramatiq.middleware.pipelines.Pipelines", {}), + ("dramatiq.middleware.retries.Retries", {}), + ), + ) + + @property + def channel_prefix(self) -> str: + return self.conf.get("channel_prefix", "dramatiq") + + @property + def task_model(self) -> str: + return self.conf["task_model"] + + @property + def task_purge_interval(self) -> int: + # 24 hours + return self.conf.get("task_purge_interval", 24 * 60 * 60) + + @property + def task_expiration(self) -> int: + # 30 days + return self.conf.get("task_expiration", 60 * 60 * 24 * 30) + + @property + def result_backend(self) -> str: + return self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend") + + @property + def result_backend_args(self) -> tuple[Any]: + return self.conf.get("result_backend_args", ()) + + @property + def result_backend_kwargs(self) -> dict[str, Any]: + return self.conf.get("result_backend_kwargs", {}) + + @property + def autodiscovery(self) -> dict[str, Any]: + autodiscovery = { + "enabled": 
False, + "setup_module": "django_dramatiq_postgres.setup", + "apps_prefix": None, + "actors_module_name": "tasks", + "modules_callback": None, + **self.conf.get("autodiscovery", {}), + } + if not autodiscovery["enabled"] and not autodiscovery["modules_callback"]: + raise ImproperlyConfigured( + "One of DRAMATIQ.autodiscovery.enabled or " + "DRAMATIQ.autodiscovery.modules_callback must be configured." + ) + return autodiscovery + + @property + def worker(self) -> dict[str, Any]: + return { + "use_gevent": False, + "watch_folder": ".", + "watch_use_polling": False, + "processes": None, + "threads": None, + "consumer_listen_timeout": 30, + **self.conf.get("worker", {}), + } + + @property + def scheduler_class(self) -> str: + return self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler") + + @property + def schedule_model(self) -> str | None: + return self.conf.get("schedule_model") + + @property + def scheduler_interval(self) -> int: + return self.conf.get("scheduler_interval", 60) + + @property + def test(self) -> bool: + return self.conf.get("test", False) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/forks.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/forks.py new file mode 100644 index 0000000000..f35dd637ae --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/forks.py @@ -0,0 +1,18 @@ +import os +from signal import pause + +from django.utils.module_loading import import_module + +from django_dramatiq_postgres.conf import Conf + + +def worker_metrics(): + import_module(Conf().autodiscovery["setup_module"]) + + from django_dramatiq_postgres.middleware import MetricsMiddleware + + MetricsMiddleware.run( + os.getenv("dramatiq_prom_host", "0.0.0.0"), # nosec + int(os.getenv("dramatiq_prom_port", "9191")), + ) + pause() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/__init__.py 
b/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/__init__.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/worker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/worker.py new file mode 100644 index 0000000000..f0b88bf1b1 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/management/commands/worker.py @@ -0,0 +1,100 @@ +import os +import sys + +from django.apps.registry import apps +from django.core.management.base import BaseCommand +from django.utils.module_loading import import_string, module_has_submodule + +from django_dramatiq_postgres.conf import Conf + + +class Command(BaseCommand): + """Run worker""" + + def add_arguments(self, parser): + parser.add_argument( + "--pid-file", + action="store", + default=None, + dest="pid_file", + help="PID file", + ) + parser.add_argument( + "--watch", + action="store_true", + default=False, + dest="watch", + help="Watch for file changes", + ) + + def handle( + self, + pid_file, + watch, + verbosity, + **options, + ): + worker = Conf().worker + executable_name = "dramatiq-gevent" if worker["use_gevent"] else "dramatiq" + executable_path = self._resolve_executable(executable_name) + watch_args = ["--watch", worker["watch_folder"]] if watch else [] + if watch_args and worker["watch_use_polling"]: + watch_args.append("--watch-use-polling") + + parallel_args = [] + if processes := worker["processes"]: + parallel_args.extend(["--processes", str(processes)]) + if threads := worker["threads"]: + parallel_args.extend(["--threads", str(threads)]) + + pid_file_args = [] + if pid_file is not None: + pid_file_args 
= ["--pid-file", pid_file] + + verbosity_args = ["-v"] * (verbosity - 1) + + tasks_modules = self._discover_tasks_modules() + process_args = [ + executable_name, + "--path", + ".", + *parallel_args, + *watch_args, + *pid_file_args, + *verbosity_args, + *tasks_modules, + ] + + os.execvp(executable_path, process_args) # nosec + + def _resolve_executable(self, exec_name: str): + bin_dir = os.path.dirname(sys.executable) + if bin_dir: + for d in [bin_dir, os.path.join(bin_dir, "Scripts")]: + exec_path = os.path.join(d, exec_name) + if os.path.isfile(exec_path): + return exec_path + return exec_name + + def _discover_tasks_modules(self) -> list[str]: + # Does not support a tasks directory + autodiscovery = Conf().autodiscovery + modules = [autodiscovery["setup_module"]] + + if autodiscovery["enabled"]: + for app in apps.get_app_configs(): + if autodiscovery["apps_prefix"] and not app.name.startswith( + autodiscovery["apps_prefix"] + ): + continue + if module_has_submodule(app.module, autodiscovery["actors_module_name"]): + modules.append(f"{app.name}.{autodiscovery['actors_module_name']}") + else: + modules_callback = autodiscovery["modules_callback"] + callback = ( + modules_callback + if not isinstance(modules_callback, str) + else import_string(modules_callback) + ) + modules.extend(callback()) + return modules diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py new file mode 100644 index 0000000000..e4dcba7b92 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -0,0 +1,309 @@ +import contextvars +import os +import socket +from http.server import BaseHTTPRequestHandler +from http.server import HTTPServer as BaseHTTPServer +from ipaddress import IPv6Address, ip_address +from typing import Any + +from django.db import ( + close_old_connections, + connections, +) +from dramatiq.actor import Actor +from dramatiq.broker 
import Broker +from dramatiq.common import current_millis +from dramatiq.logging import get_logger +from dramatiq.message import Message +from dramatiq.middleware.middleware import Middleware + +from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import TaskBase + + +class HTTPServer(BaseHTTPServer): + def server_bind(self): + self.socket.close() + + host, port = self.server_address[:2] + if host == "0.0.0.0": # nosec + host = "::" # nosec + + # Strip IPv6 brackets + if host.startswith("[") and host.endswith("]"): + host = host[1:-1] + + self.server_address = (host, port) + + self.address_family = ( + socket.AF_INET6 if isinstance(ip_address(host), IPv6Address) else socket.AF_INET + ) + + self.socket = socket.create_server( + self.server_address, + family=self.address_family, + dualstack_ipv6=self.address_family == socket.AF_INET6, + ) + + self.server_name = socket.getfqdn(host) + self.server_port = port + + +class DbConnectionMiddleware(Middleware): + def _close_old_connections(self, *args, **kwargs): + if Conf().test: + return + close_old_connections() + + before_process_message = _close_old_connections + after_process_message = _close_old_connections + + def _close_connections(self, *args, **kwargs): + connections.close_all() + + before_consumer_thread_shutdown = _close_connections + before_worker_thread_shutdown = _close_connections + before_worker_shutdown = _close_connections + + +class FullyQualifiedActorName(Middleware): + def before_declare_actor(self, broker: Broker, actor: Actor): + actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" + + +class CurrentTaskNotFound(Exception): + """ + Not current task found. Did you call get_task outside a running task? 
+ """ + + +class CurrentTask(Middleware): + def __init__(self): + self.logger = get_logger(__name__, type(self)) + + # This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile + _TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar( + "_TASKS", + default=None, + ) + + @classmethod + def get_task(cls) -> TaskBase: + task = cls._TASKS.get() + if not task: + raise CurrentTaskNotFound() + return task[-1] + + def before_process_message(self, broker: Broker, message: Message): + tasks = self._TASKS.get() + if tasks is None: + tasks = [] + tasks.append(message.options["task"]) + self._TASKS.set(tasks) + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + tasks: list[TaskBase] | None = self._TASKS.get() + if tasks is None or len(tasks) == 0: + return + + task = tasks[-1] + fields_to_exclude = { + "message_id", + "queue_name", + "actor_name", + "message", + "state", + "mtime", + "result", + "result_expiry", + } + fields_to_update = [ + f.name + for f in task._meta.get_fields() + if f.name not in fields_to_exclude and not f.auto_created and f.column + ] + if fields_to_update: + task.save(update_fields=fields_to_update) + self._TASKS.set(tasks[:-1]) + + def after_skip_message(self, broker: Broker, message: Message): + self.after_process_message(broker, message) + + +class MetricsMiddleware(Middleware): + def __init__( + self, + prefix: str, + multiproc_dir: str, + labels: list[str] | None = None, + ): + super().__init__() + self.prefix = prefix + self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"] + + self.delayed_messages = set() + self.message_start_times = {} + + os.makedirs(multiproc_dir, exist_ok=True) + os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", multiproc_dir) + + @property + def forks(self): + from django_dramatiq_postgres.forks import worker_metrics + + return 
[worker_metrics] + + def before_worker_boot(self, broker: Broker, worker): + if Conf().test: + return + + from prometheus_client import Counter, Gauge, Histogram + + self.total_messages = Counter( + f"{self.prefix}_tasks_total", + "The total number of tasks processed.", + self.labels, + ) + self.total_errored_messages = Counter( + f"{self.prefix}_tasks_errors_total", + "The total number of errored tasks.", + self.labels, + ) + self.total_retried_messages = Counter( + f"{self.prefix}_tasks_retries_total", + "The total number of retried tasks.", + self.labels, + ) + self.total_rejected_messages = Counter( + f"{self.prefix}_tasks_rejected_total", + "The total number of dead-lettered tasks.", + self.labels, + ) + self.inprogress_messages = Gauge( + f"{self.prefix}_tasks_inprogress", + "The number of tasks in progress.", + self.labels, + multiprocess_mode="livesum", + ) + self.inprogress_delayed_messages = Gauge( + f"{self.prefix}_tasks_delayed_inprogress", + "The number of delayed tasks in memory.", + self.labels, + ) + self.messages_durations = Histogram( + f"{self.prefix}_tasks_duration_miliseconds", + "The time spent processing tasks.", + self.labels, + buckets=( + 5, + 10, + 25, + 50, + 75, + 100, + 250, + 500, + 750, + 1_000, + 2_500, + 5_000, + 7_500, + 10_000, + 30_000, + 60_000, + 600_000, + 900_000, + 1_800_000, + 3_600_000, + float("inf"), + ), + ) + + def after_worker_shutdown(self, broker: Broker, worker): + from prometheus_client import multiprocess + + # TODO: worker_id + multiprocess.mark_process_dead(os.getpid()) + + def _make_labels(self, message: Message) -> list[str]: + return [message.queue_name, message.actor_name] + + def after_nack(self, broker: Broker, message: Message): + self.total_rejected_messages.labels(*self._make_labels(message)).inc() + + def after_enqueue(self, broker: Broker, message: Message, delay: int): + if "retries" in message.options: + self.total_retried_messages.labels(*self._make_labels(message)).inc() + + def 
before_delay_message(self, broker: Broker, message: Message): + self.delayed_messages.add(message.message_id) + self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc() + + def before_process_message(self, broker: Broker, message: Message): + labels = self._make_labels(message) + if message.message_id in self.delayed_messages: + self.delayed_messages.remove(message.message_id) + self.inprogress_delayed_messages.labels(*labels).dec() + + self.inprogress_messages.labels(*labels).inc() + self.message_start_times[message.message_id] = current_millis() + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + labels = self._make_labels(message) + + message_start_time = self.message_start_times.pop(message.message_id, current_millis()) + message_duration = current_millis() - message_start_time + self.messages_durations.labels(*labels).observe(message_duration) + + self.inprogress_messages.labels(*labels).dec() + self.total_messages.labels(*labels).inc() + if exception is not None: + self.total_errored_messages.labels(*labels).inc() + + after_skip_message = after_process_message + + @staticmethod + def run(addr: str, port: int): + try: + server = HTTPServer((addr, port), _MetricsHandler) + server.serve_forever() + except OSError: + get_logger(__name__, type(MetricsMiddleware)).warning( + "Port is already in use, not starting metrics server" + ) + + +class _MetricsHandler(BaseHTTPRequestHandler): + def do_GET(self): + from prometheus_client import ( + CONTENT_TYPE_LATEST, + CollectorRegistry, + generate_latest, + multiprocess, + ) + + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + output = generate_latest(registry) + self.send_response(200) + self.send_header("Content-Type", CONTENT_TYPE_LATEST) + self.end_headers() + self.wfile.write(output) + + def log_message(self, format, *args): + logger = get_logger(__name__, type(self)) 
+ logger.debug(format, *args) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py new file mode 100644 index 0000000000..a25f3d1f41 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py @@ -0,0 +1,162 @@ +import pickle # nosec +from enum import StrEnum, auto +from uuid import uuid4 + +import pgtrigger +from cron_converter import Cron +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ValidationError +from django.db import models +from django.utils.timezone import datetime, now, timedelta +from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor +from dramatiq.broker import Broker, get_broker +from dramatiq.message import Message + +from django_dramatiq_postgres.conf import Conf + +CHANNEL_PREFIX = f"{Conf().channel_prefix}.tasks" + + +class ChannelIdentifier(StrEnum): + ENQUEUE = auto() + LOCK = auto() + + +class TaskState(models.TextChoices): + """Task system-state. 
Reported by the task runners""" + + QUEUED = "queued" + CONSUMED = "consumed" + REJECTED = "rejected" + DONE = "done" + + +class TaskBase(models.Model): + message_id = models.UUIDField(primary_key=True, default=uuid4) + queue_name = models.TextField(default="default", help_text=_("Queue name")) + + actor_name = models.TextField(help_text=_("Dramatiq actor name")) + message = models.BinaryField(null=True, help_text=_("Message body")) + state = models.CharField( + default=TaskState.QUEUED, + choices=TaskState.choices, + help_text=_("Task status"), + ) + mtime = models.DateTimeField(default=now, help_text=_("Task last modified time")) + + result = models.BinaryField(null=True, help_text=_("Task result")) + result_expiry = models.DateTimeField(null=True, help_text=_("Result expiry time")) + + class Meta: + abstract = True + verbose_name = _("Task") + verbose_name_plural = _("Tasks") + indexes = (models.Index(fields=("state", "mtime")),) + triggers = ( + pgtrigger.Trigger( + name="notify_enqueueing", + operation=pgtrigger.Insert | pgtrigger.Update, + when=pgtrigger.After, + condition=pgtrigger.Q(new__state=TaskState.QUEUED), + timing=pgtrigger.Deferred, + func=f""" + PERFORM pg_notify( + '{CHANNEL_PREFIX}.' 
|| NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}', + NEW.message_id::text + ); + RETURN NEW; + """, # noqa: E501 + ), + ) + + def __str__(self): + return str(self.message_id) + + +def validate_crontab(value): + try: + Cron(value) + except ValueError as exc: + raise ValidationError( + _("%(value)s is not a valid crontab"), + params={"value": value}, + ) from exc + + +class ScheduleBase(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) + + actor_name = models.TextField(editable=False, help_text=_("Dramatiq actor to call")) + args = models.BinaryField(editable=False, help_text=_("Args to send to the actor")) + kwargs = models.BinaryField(editable=False, help_text=_("Kwargs to send to the actor")) + options = models.BinaryField(editable=False, help_text=_("Options to send to the actor")) + + rel_obj_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True) + rel_obj_id = models.TextField(null=True) + rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id") + + crontab = models.TextField(validators=[validate_crontab], help_text=_("When to schedule tasks")) + paused = models.BooleanField(default=False, help_text=_("Pause this schedule")) + + next_run = models.DateTimeField(auto_now_add=True, editable=False) + + class Meta: + abstract = True + verbose_name = _("Schedule") + verbose_name_plural = _("Schedules") + triggers = ( + pgtrigger.Trigger( + name="set_next_run_on_paused", + operation=pgtrigger.Update, + when=pgtrigger.Before, + condition=pgtrigger.Q(new__paused=True) & pgtrigger.Q(old__paused=False), + func=""" + NEW.next_run = to_timestamp(0); + RETURN NEW; + """, + ), + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._original_crontab = self.crontab + + def __str__(self): + return f"Schedule {self.actor_name} ({self.id})" + + def save(self, *args, **kwargs): + if self.crontab != self._original_crontab: + self.next_run = self.compute_next_run(now()) 
+ + super().save(*args, **kwargs) + + self._original_crontab = self.crontab + + @classmethod + def dispatch_by_actor(cls, actor: Actor): + """Dispatch a schedule by looking up its actor. + Only available for schedules without custom arguments.""" + schedule = cls.objects.filter(actor_name=actor.actor_name, paused=False).first() + if schedule: + schedule.send() + + def send(self, broker: Broker | None = None) -> Message: + broker = broker or get_broker() + actor: Actor = broker.get_actor(self.actor_name) + return actor.send_with_options( + args=pickle.loads(self.args), # nosec + kwargs=pickle.loads(self.kwargs), # nosec + rel_obj=self, + **pickle.loads(self.options), # nosec + ) + + def compute_next_run(self, next_run: datetime | None = None) -> datetime: + next_run: datetime = self.next_run if not next_run else next_run + while True: + next_run = Cron(self.crontab).schedule(next_run).next() + if next_run > now(): + return next_run + # Force to calculate the one after + next_run += timedelta(minutes=1) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py new file mode 100644 index 0000000000..fdfdcd04a0 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py @@ -0,0 +1,43 @@ +from django.db import DEFAULT_DB_ALIAS +from django.db.models import QuerySet +from django.utils import timezone +from django.utils.functional import cached_property +from django.utils.module_loading import import_string +from dramatiq.message import Message +from dramatiq.results.backend import Missing, MResult, Result, ResultBackend + +from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import TaskBase + + +class PostgresBackend(ResultBackend): + def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, **kwargs): + super().__init__(*args, **kwargs) + self.db_alias = db_alias + + @cached_property + def model(self) -> 
type[TaskBase]: + return import_string(Conf().task_model) + + @property + def query_set(self) -> QuerySet: + return self.model.objects.using(self.db_alias).defer("message") + + def build_message_key(self, message: Message) -> str: + return str(message.message_id) + + def _get(self, message_key: str) -> MResult: + message = self.query_set.filter(message_id=message_key).first() + if message is None: + return Missing + data = message.result + if data is None: + return Missing + return self.encoder.decode(data) + + def _store(self, message_key: str, result: Result, ttl: int) -> None: + self.query_set.filter(message_id=message_key).update( + mtime=timezone.now(), + result=self.encoder.encode(result), + result_expiry=timezone.now() + timezone.timedelta(milliseconds=ttl), + ) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py new file mode 100644 index 0000000000..3383d62f53 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py @@ -0,0 +1,57 @@ +import pglock +from django.db import router, transaction +from django.db.models import QuerySet +from django.utils.functional import cached_property +from django.utils.module_loading import import_string +from django.utils.timezone import now +from dramatiq.broker import Broker +from dramatiq.logging import get_logger + +from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import ScheduleBase + + +class Scheduler: + broker: Broker + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logger = get_logger(__name__, type(self)) + + @cached_property + def model(self) -> type[ScheduleBase]: + return import_string(Conf().schedule_model) + + @property + def query_set(self) -> QuerySet: + return self.model.objects.filter(paused=False) + + def process_schedule(self, schedule: ScheduleBase): + schedule.next_run = 
schedule.compute_next_run() + schedule.send(self.broker) + schedule.save() + + def _lock(self) -> pglock.advisory: + return pglock.advisory( + lock_id=f"{Conf().channel_prefix}.scheduler", + side_effect=pglock.Return, + timeout=0, + ) + + def _run(self) -> int: + count = 0 + with transaction.atomic(using=router.db_for_write(self.model)): + for schedule in self.query_set.select_for_update().filter( + next_run__lt=now(), + ): + self.process_schedule(schedule) + count += 1 + return count + + def run(self): + with self._lock() as lock_acquired: + if not lock_acquired: + self.logger.debug("Could not acquire lock, skipping scheduling") + return + count = self._run() + self.logger.info(f"Sent {count} scheduled tasks") diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/setup.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/setup.py new file mode 100644 index 0000000000..8fb58f6e12 --- /dev/null +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/setup.py @@ -0,0 +1,3 @@ +import django + +django.setup() diff --git a/packages/django-dramatiq-postgres/pyproject.toml b/packages/django-dramatiq-postgres/pyproject.toml new file mode 100644 index 0000000000..f2294dc1ba --- /dev/null +++ b/packages/django-dramatiq-postgres/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "django-dramatiq-postgres" +version = "0.1.0" +description = "Django and Dramatiq integration with postgres-specific features" +requires-python = ">=3.9,<3.14" +readme = "README.md" +license = "MIT" +authors = [{ name = "Authentik Security Inc.", email = "hello@goauthentik.io" }] +keywords = ["django", "dramatiq", "postgres"] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Framework :: Django", + "Framework :: Django :: 4.2", + "Framework :: Django :: 5.0", + "Framework :: Django :: 5.1", + "Framework :: Django :: 5.2", + "Intended Audience :: Developers", + "Operating System :: MacOS", + "Operating System :: POSIX", 
+ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +dependencies = [ + "cron-converter >=1,<2", + "django >=4.2,<6.0", + "django-pgtrigger >=4,<5", + "dramatiq[watch] >=1.17,<1.18", + "tenacity >=9,<10", +] + +[project.urls] +Homepage = "https://github.com/goauthentik/authentik/tree/main/packages/django-dramatiq-postgres" +Documentation = "https://github.com/goauthentik/authentik/tree/main/packages/django-dramatiq-postgres" +Repository = "https://github.com/goauthentik/authentik/tree/main/packages/django-dramatiq-postgres" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.setuptools.packages] +find = {} diff --git a/pyproject.toml b/pyproject.toml index ed5a37d470..b3da07f5b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ authors = [{ name = "authentik Team", email = "hello@goauthentik.io" }] requires-python = "==3.13.*" dependencies = [ "argon2-cffi==25.1.0", - "celery==5.5.3", "channels==4.2.2", "channels-redis==4.3.0", "cryptography==45.0.5", @@ -16,10 +15,12 @@ dependencies = [ "django==5.1.11", "django-countries==7.6.1", "django-cte==2.0.0", + "django-dramatiq-postgres", "django-filter==25.1", "django-guardian==3.0.3", "django-model-utils==5.0.0", "django-pglock==1.7.2", + "django-pgtrigger==4.15.2", "django-prometheus==2.4.1", "django-redis==6.0.0", "django-storages[s3]==1.14.6", @@ -33,7 +34,6 @@ dependencies = [ "dumb-init==1.2.5.post1", "duo-client==5.5.0", "fido2==2.0.0", - "flower==2.0.1", "geoip2==5.1.0", "geopy==2.4.1", "google-api-python-client==2.177.0", @@ -62,7 +62,6 @@ dependencies = [ "setproctitle==1.3.6", "structlog==25.4.0", "swagger-spec-validator==3.0.4", - 
"tenant-schemas-celery==3.0.0", "twilio==9.7.0", "ua-parser==1.0.1", "unidecode==1.4.0", @@ -115,8 +114,12 @@ no-binary-package = [ ] [tool.uv.sources] -opencontainers = { git = "https://github.com/vsoch/oci-python", rev = "ceb4fcc090851717a3069d78e85ceb1e86c2740c" } djangorestframework = { git = "https://github.com/goauthentik/django-rest-framework", rev = "896722bab969fabc74a08b827da59409cf9f1a4e" } +django-dramatiq-postgres = { workspace = true } +opencontainers = { git = "https://github.com/vsoch/oci-python", rev = "ceb4fcc090851717a3069d78e85ceb1e86c2740c" } + +[tool.uv.workspace] +members = ["packages/django-dramatiq-postgres"] [project.scripts] ak = "lifecycle.ak:main" @@ -136,12 +139,12 @@ skip = [ "unittest.xml", "./blueprints/schema.json", "go.sum", - "locale", - "**/web/src/locales", + "locale", + "**/web/src/locales", "**/dist", # Distributed build output "**/storybook-static", "**/web/xliff", - "**/out", # TypeScript type-checking output + "**/out", # TypeScript type-checking output "./web/custom-elements.json", # TypeScript custom element definitions "./website/build", # TODO: Remove this after moving website to docs "./website/**/build", # TODO: Remove this after moving website to docs @@ -232,7 +235,7 @@ show_missing = true DJANGO_SETTINGS_MODULE = "authentik.root.settings" python_files = ["tests.py", "test_*.py", "*_tests.py"] junit_family = "xunit2" -addopts = "-p no:celery -p authentik.root.test_plugin --junitxml=unittest.xml -vv --full-trace --doctest-modules --import-mode=importlib" +addopts = "-p authentik.root.test_plugin --junitxml=unittest.xml -vv --full-trace --doctest-modules --import-mode=importlib --ignore=authentik/tasks/setup.py" filterwarnings = [ "ignore:defusedxml.lxml is no longer supported and will be removed in a future release.:DeprecationWarning", "ignore:SelectableGroups dict interface is deprecated. 
Use select.:DeprecationWarning", diff --git a/schema.yml b/schema.yml index e9778f5dd3..83b3221559 100644 --- a/schema.yml +++ b/schema.yml @@ -320,35 +320,6 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' - /admin/workers/: - get: - operationId: admin_workers_list - description: Get currently connected worker count. - tags: - - admin - security: - - authentik: [] - responses: - '200': - content: - application/json: - schema: - type: array - items: - $ref: '#/components/schemas/Worker' - description: '' - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/ValidationError' - description: '' - '403': - content: - application/json: - schema: - $ref: '#/components/schemas/GenericError' - description: '' /authenticators/admin/all/: get: operationId: authenticators_admin_all_list @@ -8065,145 +8036,6 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' - /events/system_tasks/: - get: - operationId: events_system_tasks_list - description: Read-only view set that returns all background tasks - parameters: - - in: query - name: name - schema: - type: string - - name: ordering - required: false - in: query - description: Which field to use when ordering the results. - schema: - type: string - - name: page - required: false - in: query - description: A page number within the paginated result set. - schema: - type: integer - - name: page_size - required: false - in: query - description: Number of results to return per page. - schema: - type: integer - - name: search - required: false - in: query - description: A search term. 
- schema: - type: string - - in: query - name: status - schema: - type: string - enum: - - error - - successful - - unknown - - warning - - in: query - name: uid - schema: - type: string - tags: - - events - security: - - authentik: [] - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/PaginatedSystemTaskList' - description: '' - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/ValidationError' - description: '' - '403': - content: - application/json: - schema: - $ref: '#/components/schemas/GenericError' - description: '' - /events/system_tasks/{uuid}/: - get: - operationId: events_system_tasks_retrieve - description: Read-only view set that returns all background tasks - parameters: - - in: path - name: uuid - schema: - type: string - format: uuid - description: A UUID string identifying this System Task. - required: true - tags: - - events - security: - - authentik: [] - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/SystemTask' - description: '' - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/ValidationError' - description: '' - '403': - content: - application/json: - schema: - $ref: '#/components/schemas/GenericError' - description: '' - /events/system_tasks/{uuid}/run/: - post: - operationId: events_system_tasks_run_create - description: Run task - parameters: - - in: path - name: uuid - schema: - type: string - format: uuid - description: A UUID string identifying this System Task. 
- required: true - tags: - - events - security: - - authentik: [] - responses: - '204': - description: Task retried successfully - '404': - description: Task not found - '500': - description: Failed to retry task - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/ValidationError' - description: '' - '403': - content: - application/json: - schema: - $ref: '#/components/schemas/GenericError' - description: '' /events/transports/: get: operationId: events_transports_list @@ -19641,8 +19473,6 @@ paths: schema: $ref: '#/components/schemas/SyncStatus' description: '' - '404': - description: Task not found '400': content: application/json: @@ -20684,8 +20514,6 @@ paths: schema: $ref: '#/components/schemas/SyncStatus' description: '' - '404': - description: Task not found '400': content: application/json: @@ -23219,8 +23047,6 @@ paths: schema: $ref: '#/components/schemas/SyncStatus' description: '' - '404': - description: Task not found '400': content: application/json: @@ -24939,6 +24765,7 @@ paths: - authentik_stages_user_login.userloginstage - authentik_stages_user_logout.userlogoutstage - authentik_stages_user_write.userwritestage + - authentik_tasks_schedules.schedule - authentik_tenants.domain required: true - in: query @@ -25188,6 +25015,7 @@ paths: - authentik_stages_user_login.userloginstage - authentik_stages_user_logout.userlogoutstage - authentik_stages_user_write.userwritestage + - authentik_tasks_schedules.schedule - authentik_tenants.domain required: true - in: query @@ -28322,7 +28150,7 @@ paths: /sources/kerberos/{slug}/sync/status/: get: operationId: sources_kerberos_sync_status_retrieve - description: Get source's sync status + description: Get provider's sync status parameters: - in: path name: slug @@ -28339,7 +28167,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/KerberosSyncStatus' + $ref: '#/components/schemas/SyncStatus' description: '' '400': content: @@ -28772,7 +28600,7 @@ paths: 
/sources/ldap/{slug}/sync/status/: get: operationId: sources_ldap_sync_status_retrieve - description: Get source's sync status + description: Get provider's sync status parameters: - in: path name: slug @@ -40617,6 +40445,422 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' + /tasks/schedules/: + get: + operationId: tasks_schedules_list + parameters: + - in: query + name: actor_name + schema: + type: string + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - in: query + name: paused + schema: + type: boolean + - in: query + name: rel_obj_content_type__app_label + schema: + type: string + - in: query + name: rel_obj_content_type__model + schema: + type: string + - in: query + name: rel_obj_id + schema: + type: string + - in: query + name: rel_obj_id__isnull + schema: + type: boolean + - name: search + required: false + in: query + description: A search term. + schema: + type: string + tags: + - tasks + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedScheduleList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/schedules/{id}/: + get: + operationId: tasks_schedules_retrieve + parameters: + - in: path + name: id + schema: + type: string + format: uuid + description: A UUID string identifying this Schedule. 
+ required: true + tags: + - tasks + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Schedule' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + put: + operationId: tasks_schedules_update + parameters: + - in: path + name: id + schema: + type: string + format: uuid + description: A UUID string identifying this Schedule. + required: true + tags: + - tasks + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ScheduleRequest' + required: true + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Schedule' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + patch: + operationId: tasks_schedules_partial_update + parameters: + - in: path + name: id + schema: + type: string + format: uuid + description: A UUID string identifying this Schedule. 
+ required: true + tags: + - tasks + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedScheduleRequest' + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Schedule' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/schedules/{id}/send/: + post: + operationId: tasks_schedules_send_create + description: Trigger this schedule now + parameters: + - in: path + name: id + schema: + type: string + format: uuid + description: A UUID string identifying this Schedule. + required: true + tags: + - tasks + security: + - authentik: [] + responses: + '204': + description: Schedule sent successfully + '404': + description: Schedule not found + '500': + description: Failed to send schedule + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/tasks/: + get: + operationId: tasks_tasks_list + parameters: + - in: query + name: actor_name + schema: + type: string + - in: query + name: aggregated_status + schema: + type: array + items: + type: string + enum: + - consumed + - done + - error + - info + - queued + - rejected + - warning + explode: true + style: form + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. 
+ schema: + type: integer + - in: query + name: queue_name + schema: + type: string + - in: query + name: rel_obj_content_type__app_label + schema: + type: string + - in: query + name: rel_obj_content_type__model + schema: + type: string + - in: query + name: rel_obj_id + schema: + type: string + - in: query + name: rel_obj_id__isnull + schema: + type: boolean + - name: search + required: false + in: query + description: A search term. + schema: + type: string + - in: query + name: state + schema: + type: string + enum: + - consumed + - done + - queued + - rejected + description: |+ + Task status + + tags: + - tasks + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedTaskList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/tasks/{message_id}/: + get: + operationId: tasks_tasks_retrieve + parameters: + - in: path + name: message_id + schema: + type: string + format: uuid + description: A UUID string identifying this Task. + required: true + tags: + - tasks + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Task' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/tasks/{message_id}/retry/: + post: + operationId: tasks_tasks_retry_create + description: Retry task + parameters: + - in: path + name: message_id + schema: + type: string + format: uuid + description: A UUID string identifying this Task. 
+ required: true + tags: + - tasks + security: + - authentik: [] + responses: + '204': + description: Task retried successfully + '400': + description: Task is not in a retryable state + '404': + description: Task not found + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /tasks/workers: + get: + operationId: tasks_workers_list + description: Get currently connected worker count. + tags: + - tasks + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Worker' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' /tenants/domains/: get: operationId: tenants_domains_list @@ -41160,9 +41404,11 @@ components: AppEnum: enum: - authentik.tenants + - authentik.tasks - authentik.admin - authentik.api - authentik.crypto + - authentik.events - authentik.flows - authentik.outposts - authentik.policies.dummy @@ -41210,6 +41456,7 @@ components: - authentik.stages.user_login - authentik.stages.user_logout - authentik.stages.user_write + - authentik.tasks.schedules - authentik.brands - authentik.blueprints - authentik.core @@ -41223,7 +41470,6 @@ components: - authentik.enterprise.stages.authenticator_endpoint_gdtc - authentik.enterprise.stages.mtls - authentik.enterprise.stages.source - - authentik.events type: string AppleChallengeResponseRequest: type: object @@ -47431,21 +47677,6 @@ components: - name - realm - slug - KerberosSyncStatus: - type: object - description: Kerberos Source sync status - properties: - is_running: - type: boolean - readOnly: true - tasks: - type: array - items: - $ref: '#/components/schemas/SystemTask' - readOnly: true - required: - - is_running - - tasks KubernetesServiceConnection: type: object description: 
KubernetesServiceConnection Serializer @@ -48167,6 +48398,16 @@ components: - name - server_uri - slug + LastTaskStatusEnum: + enum: + - queued + - consumed + - rejected + - done + - info + - warning + - error + type: string License: type: object description: License Serializer @@ -48628,6 +48869,11 @@ components: enum: - authentik_tenants.domain - authentik_crypto.certificatekeypair + - authentik_events.event + - authentik_events.notificationtransport + - authentik_events.notification + - authentik_events.notificationrule + - authentik_events.notificationwebhookmapping - authentik_flows.flow - authentik_flows.flowstagebinding - authentik_outposts.dockerserviceconnection @@ -48708,6 +48954,7 @@ components: - authentik_stages_user_login.userloginstage - authentik_stages_user_logout.userlogoutstage - authentik_stages_user_write.userwritestage + - authentik_tasks_schedules.schedule - authentik_brands.brand - authentik_blueprints.blueprintinstance - authentik_core.group @@ -48725,11 +48972,6 @@ components: - authentik_stages_authenticator_endpoint_gdtc.authenticatorendpointgdtcstage - authentik_stages_mtls.mutualtlsstage - authentik_stages_source.sourcestage - - authentik_events.event - - authentik_events.notificationtransport - - authentik_events.notification - - authentik_events.notificationrule - - authentik_events.notificationwebhookmapping type: string MutualTLSStage: type: object @@ -51583,6 +51825,21 @@ components: - pagination - results - autocomplete + PaginatedScheduleList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/Schedule' + autocomplete: + $ref: '#/components/schemas/Autocomplete' + required: + - pagination + - results + - autocomplete PaginatedScopeMappingList: type: object properties: @@ -51673,21 +51930,6 @@ components: - pagination - results - autocomplete - PaginatedSystemTaskList: - type: object - properties: - pagination: - $ref: 
'#/components/schemas/Pagination' - results: - type: array - items: - $ref: '#/components/schemas/SystemTask' - autocomplete: - $ref: '#/components/schemas/Autocomplete' - required: - - pagination - - results - - autocomplete PaginatedTOTPDeviceList: type: object properties: @@ -51703,6 +51945,21 @@ components: - pagination - results - autocomplete + PaginatedTaskList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/Task' + autocomplete: + $ref: '#/components/schemas/Autocomplete' + required: + - pagination + - results + - autocomplete PaginatedTenantList: type: object properties: @@ -55337,6 +55594,20 @@ components: event_retention: type: string minLength: 1 + PatchedScheduleRequest: + type: object + properties: + rel_obj_id: + type: string + nullable: true + minLength: 1 + crontab: + type: string + minLength: 1 + description: When to schedule tasks + paused: + type: boolean + description: Pause this schedule PatchedScopeMappingRequest: type: object description: ScopeMapping Serializer @@ -59291,6 +59562,81 @@ components: - pk - provider - provider_obj + Schedule: + type: object + properties: + id: + type: string + format: uuid + readOnly: true + identifier: + type: string + readOnly: true + nullable: true + description: Unique schedule identifier + uid: + type: string + readOnly: true + actor_name: + type: string + readOnly: true + description: Dramatiq actor to call + rel_obj_app_label: + type: string + readOnly: true + rel_obj_model: + type: string + title: Python model class name + readOnly: true + rel_obj_id: + type: string + nullable: true + crontab: + type: string + description: When to schedule tasks + paused: + type: boolean + description: Pause this schedule + next_run: + type: string + format: date-time + readOnly: true + description: + type: string + nullable: true + readOnly: true + last_task_status: + allOf: + - $ref: 
'#/components/schemas/LastTaskStatusEnum' + nullable: true + readOnly: true + required: + - actor_name + - crontab + - description + - id + - identifier + - last_task_status + - next_run + - rel_obj_app_label + - rel_obj_model + - uid + ScheduleRequest: + type: object + properties: + rel_obj_id: + type: string + nullable: true + minLength: 1 + crontab: + type: string + minLength: 1 + description: When to schedule tasks + paused: + type: boolean + description: Pause this schedule + required: + - crontab ScopeMapping: type: object description: ScopeMapping Serializer @@ -59982,6 +60328,13 @@ components: $ref: '#/components/schemas/FlowSetRequest' required: - name + StateEnum: + enum: + - queued + - consumed + - rejected + - done + type: string StaticDevice: type: object description: Serializer for static authenticator devices @@ -60080,19 +60433,17 @@ components: - messages SyncStatus: type: object - description: Provider sync status + description: Provider/source sync status properties: is_running: type: boolean - readOnly: true - tasks: - type: array - items: - $ref: '#/components/schemas/SystemTask' - readOnly: true + last_successful_sync: + type: string + format: date-time + last_sync_status: + $ref: '#/components/schemas/TaskAggregatedStatusEnum' required: - is_running - - tasks SystemInfo: type: object description: Get system information. 
@@ -60168,65 +60519,6 @@ components: - http_is_secure - runtime - server_time - SystemTask: - type: object - description: Serialize TaskInfo and TaskResult - properties: - uuid: - type: string - format: uuid - readOnly: true - name: - type: string - full_name: - type: string - description: Get full name with UID - readOnly: true - uid: - type: string - description: - type: string - start_timestamp: - type: string - format: date-time - readOnly: true - finish_timestamp: - type: string - format: date-time - readOnly: true - duration: - type: number - format: double - readOnly: true - status: - $ref: '#/components/schemas/SystemTaskStatusEnum' - messages: - type: array - items: - $ref: '#/components/schemas/LogEvent' - expires: - type: string - format: date-time - nullable: true - expiring: - type: boolean - required: - - description - - duration - - finish_timestamp - - full_name - - messages - - name - - start_timestamp - - status - - uuid - SystemTaskStatusEnum: - enum: - - unknown - - successful - - warning - - error - type: string TOTPDevice: type: object description: Serializer for totp authenticator devices @@ -60258,6 +60550,72 @@ components: maxLength: 64 required: - name + Task: + type: object + properties: + message_id: + type: string + format: uuid + queue_name: + type: string + description: Queue name + actor_name: + type: string + description: Dramatiq actor name + state: + allOf: + - $ref: '#/components/schemas/StateEnum' + description: Task status + mtime: + type: string + format: date-time + description: Task last modified time + rel_obj_app_label: + type: string + readOnly: true + rel_obj_model: + type: string + title: Python model class name + readOnly: true + rel_obj_id: + type: string + nullable: true + uid: + type: string + readOnly: true + messages: + type: array + items: + $ref: '#/components/schemas/LogEvent' + previous_messages: + type: array + items: + $ref: '#/components/schemas/LogEvent' + aggregated_status: + $ref: 
'#/components/schemas/TaskAggregatedStatusEnum' + description: + type: string + nullable: true + readOnly: true + required: + - actor_name + - aggregated_status + - description + - messages + - previous_messages + - rel_obj_app_label + - rel_obj_model + - uid + TaskAggregatedStatusEnum: + enum: + - queued + - consumed + - rejected + - done + - info + - warning + - error + type: string Tenant: type: object description: Tenant Serializer diff --git a/scripts/generate_config.py b/scripts/generate_config.py index df87abdba4..6172dec9c6 100755 --- a/scripts/generate_config.py +++ b/scripts/generate_config.py @@ -46,6 +46,12 @@ def generate_local_config(): "enabled": False, "api_key": generate_id(), }, + "worker": { + "processes": 1, + "threads": 1, + "consumer_listen_timeout": "seconds=10", + "scheduler_interval": "seconds=30", + }, } diff --git a/tests/e2e/test_source_ldap_samba.py b/tests/e2e/test_source_ldap_samba.py index 77f1444dae..0ec40a1f86 100644 --- a/tests/e2e/test_source_ldap_samba.py +++ b/tests/e2e/test_source_ldap_samba.py @@ -11,6 +11,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer +from authentik.tasks.models import Task from tests.e2e.utils import SeleniumTestCase, retry @@ -60,7 +61,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): name="goauthentik.io/sources/ldap/default-name" ) ) - UserLDAPSynchronizer(source).sync_full() + UserLDAPSynchronizer(source, Task()).sync_full() self.assertTrue(User.objects.filter(username="bob").exists()) self.assertTrue(User.objects.filter(username="james").exists()) self.assertTrue(User.objects.filter(username="john").exists()) @@ -93,9 +94,9 @@ class TestSourceLDAPSamba(SeleniumTestCase): managed="goauthentik.io/sources/ldap/default-name" ) ) - 
GroupLDAPSynchronizer(source).sync_full() - UserLDAPSynchronizer(source).sync_full() - MembershipLDAPSynchronizer(source).sync_full() + GroupLDAPSynchronizer(source, Task()).sync_full() + UserLDAPSynchronizer(source, Task()).sync_full() + MembershipLDAPSynchronizer(source, Task()).sync_full() self.assertIsNotNone(User.objects.get(username="bob")) self.assertIsNotNone(User.objects.get(username="james")) self.assertIsNotNone(User.objects.get(username="john")) @@ -139,7 +140,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): name="goauthentik.io/sources/ldap/default-name" ) ) - UserLDAPSynchronizer(source).sync_full() + UserLDAPSynchronizer(source, Task()).sync_full() username = "bob" password = generate_id() result = self.samba.exec_run( @@ -162,7 +163,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): ) self.assertEqual(result.exit_code, 0) # Sync again - UserLDAPSynchronizer(source).sync_full() + UserLDAPSynchronizer(source, Task()).sync_full() user.refresh_from_db() # Since password in samba was checked, it should be invalidated here too self.assertFalse(user.has_usable_password()) diff --git a/tests/integration/test_outpost_docker.py b/tests/integration/test_outpost_docker.py index 26d4c0993c..5ec3a93b64 100644 --- a/tests/integration/test_outpost_docker.py +++ b/tests/integration/test_outpost_docker.py @@ -45,7 +45,7 @@ class OutpostDockerTests(DockerTestCase, ChannelsLiveServerTestCase): }, ) # Ensure that local connection have been created - outpost_connection_discovery() + outpost_connection_discovery.send() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_outpost_kubernetes.py b/tests/integration/test_outpost_kubernetes.py index 099eddc87a..01566c4f06 100644 --- a/tests/integration/test_outpost_kubernetes.py +++ b/tests/integration/test_outpost_kubernetes.py @@ -24,7 +24,7 @@ class OutpostKubernetesTests(TestCase): def setUp(self): super().setUp() # Ensure that local 
connection have been created - outpost_connection_discovery() + outpost_connection_discovery.send() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_proxy_docker.py b/tests/integration/test_proxy_docker.py index a11116da27..f0268cfb45 100644 --- a/tests/integration/test_proxy_docker.py +++ b/tests/integration/test_proxy_docker.py @@ -45,7 +45,7 @@ class TestProxyDocker(DockerTestCase, ChannelsLiveServerTestCase): }, ) # Ensure that local connection have been created - outpost_connection_discovery() + outpost_connection_discovery.send() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_proxy_kubernetes.py b/tests/integration/test_proxy_kubernetes.py index 477beb3851..262e57ad45 100644 --- a/tests/integration/test_proxy_kubernetes.py +++ b/tests/integration/test_proxy_kubernetes.py @@ -23,7 +23,7 @@ class TestProxyKubernetes(TestCase): def setUp(self): # Ensure that local connection have been created - outpost_connection_discovery() + outpost_connection_discovery.send() self.controller = None @pytest.mark.timeout(120) diff --git a/uv.lock b/uv.lock index 2563859080..4b92ffd2cf 100644 --- a/uv.lock +++ b/uv.lock @@ -2,6 +2,12 @@ version = 1 revision = 2 requires-python = "==3.13.*" +[manifest] +members = [ + "authentik", + "django-dramatiq-postgres", +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -69,18 +75,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] -[[package]] -name = "amqp" -version = "5.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "vine" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/79/fc/ec94a357dfc6683d8c86f8b4cfa5416a4c36b28052ec8260c77aca96a443/amqp-5.3.1.tar.gz", hash = "sha256:cddc00c725449522023bad949f70fff7b48f0b1ade74d170a6f10ab044739432", size = 129013, upload-time = "2024-11-12T19:55:44.051Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/99/fc813cd978842c26c82534010ea849eee9ab3a13ea2b74e95cb9c99e747b/amqp-5.3.1-py3-none-any.whl", hash = "sha256:43b3319e1b4e7d1251833a93d672b4af1e40f3d632d479b98661a95f117880a2", size = 50944, upload-time = "2024-11-12T19:55:41.782Z" }, -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -169,7 +163,6 @@ version = "2025.6.4" source = { editable = "." } dependencies = [ { name = "argon2-cffi" }, - { name = "celery" }, { name = "channels" }, { name = "channels-redis" }, { name = "cryptography" }, @@ -179,10 +172,12 @@ dependencies = [ { name = "django" }, { name = "django-countries" }, { name = "django-cte" }, + { name = "django-dramatiq-postgres" }, { name = "django-filter" }, { name = "django-guardian" }, { name = "django-model-utils" }, { name = "django-pglock" }, + { name = "django-pgtrigger" }, { name = "django-prometheus" }, { name = "django-redis" }, { name = "django-storages", extra = ["s3"] }, @@ -196,7 +191,6 @@ dependencies = [ { name = "dumb-init" }, { name = "duo-client" }, { name = "fido2" }, - { name = "flower" }, { name = "geoip2" }, { name = "geopy" }, { name = "google-api-python-client" }, @@ -225,7 +219,6 @@ dependencies = [ { name = "setproctitle" }, { name = "structlog" }, { name = "swagger-spec-validator" }, - { name = "tenant-schemas-celery" }, { name = "twilio" }, { name = "ua-parser" }, { name = "unidecode" }, @@ -268,7 +261,6 @@ dev = [ [package.metadata] requires-dist = [ { name = "argon2-cffi", specifier = "==25.1.0" }, - { name = "celery", specifier = "==5.5.3" }, { name = "channels", specifier = "==4.2.2" }, { name = "channels-redis", specifier = "==4.3.0" }, { name = "cryptography", specifier = 
"==45.0.5" }, @@ -278,10 +270,12 @@ requires-dist = [ { name = "django", specifier = "==5.1.11" }, { name = "django-countries", specifier = "==7.6.1" }, { name = "django-cte", specifier = "==2.0.0" }, + { name = "django-dramatiq-postgres", editable = "packages/django-dramatiq-postgres" }, { name = "django-filter", specifier = "==25.1" }, { name = "django-guardian", specifier = "==3.0.3" }, { name = "django-model-utils", specifier = "==5.0.0" }, { name = "django-pglock", specifier = "==1.7.2" }, + { name = "django-pgtrigger", specifier = "==4.15.2" }, { name = "django-prometheus", specifier = "==2.4.1" }, { name = "django-redis", specifier = "==6.0.0" }, { name = "django-storages", extras = ["s3"], specifier = "==1.14.6" }, @@ -295,7 +289,6 @@ requires-dist = [ { name = "dumb-init", specifier = "==1.2.5.post1" }, { name = "duo-client", specifier = "==5.5.0" }, { name = "fido2", specifier = "==2.0.0" }, - { name = "flower", specifier = "==2.0.1" }, { name = "geoip2", specifier = "==5.1.0" }, { name = "geopy", specifier = "==2.4.1" }, { name = "google-api-python-client", specifier = "==2.177.0" }, @@ -324,7 +317,6 @@ requires-dist = [ { name = "setproctitle", specifier = "==1.3.6" }, { name = "structlog", specifier = "==25.4.0" }, { name = "swagger-spec-validator", specifier = "==3.0.4" }, - { name = "tenant-schemas-celery", specifier = "==3.0.0" }, { name = "twilio", specifier = "==9.7.0" }, { name = "ua-parser", specifier = "==1.0.1" }, { name = "unidecode", specifier = "==1.4.0" }, @@ -543,15 +535,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/cf/45fb5261ece3e6b9817d3d82b2f343a505fd58674a92577923bc500bd1aa/bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b", size = 152799, upload-time = "2025-02-28T01:23:53.139Z" }, ] -[[package]] -name = "billiard" -version = "4.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/7c/58/1546c970afcd2a2428b1bfafecf2371d8951cc34b46701bea73f4280989e/billiard-4.2.1.tar.gz", hash = "sha256:12b641b0c539073fc8d3f5b8b7be998956665c4233c7c1fcd66a7e677c4fb36f", size = 155031, upload-time = "2024-09-21T13:40:22.491Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/30/da/43b15f28fe5f9e027b41c539abc5469052e9d48fd75f8ff094ba2a0ae767/billiard-4.2.1-py3-none-any.whl", hash = "sha256:40b59a4ac8806ba2c2369ea98d876bc6108b051c227baffd928c644d15d8f3cb", size = 86766, upload-time = "2024-09-21T13:40:20.188Z" }, -] - [[package]] name = "black" version = "25.1.0" @@ -646,25 +629,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9b/ef/1c4698cac96d792005ef0611832f38eaee477c275ab4b02cbfc4daba7ad3/cbor2-5.6.5-py3-none-any.whl", hash = "sha256:3038523b8fc7de312bb9cdcbbbd599987e64307c4db357cd2030c472a6c7d468", size = 23752, upload-time = "2024-10-09T12:26:23.167Z" }, ] -[[package]] -name = "celery" -version = "5.5.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "billiard" }, - { name = "click" }, - { name = "click-didyoumean" }, - { name = "click-plugins" }, - { name = "click-repl" }, - { name = "kombu" }, - { name = "python-dateutil" }, - { name = "vine" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bb/7d/6c289f407d219ba36d8b384b42489ebdd0c84ce9c413875a8aae0c85f35b/celery-5.5.3.tar.gz", hash = "sha256:6c972ae7968c2b5281227f01c3a3f984037d21c5129d07bf3550cc2afc6b10a5", size = 1667144, upload-time = "2025-06-01T11:08:12.563Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" }, -] - [[package]] name = "certifi" version = "2025.7.14" @@ -763,43 +727,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] -[[package]] -name = "click-didyoumean" -version = "0.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/30/ce/217289b77c590ea1e7c24242d9ddd6e249e52c795ff10fac2c50062c48cb/click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463", size = 3089, upload-time = "2024-03-24T08:22:07.499Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1b/5b/974430b5ffdb7a4f1941d13d83c64a0395114503cc357c6b9ae4ce5047ed/click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c", size = 3631, upload-time = "2024-03-24T08:22:06.356Z" }, -] - -[[package]] -name = "click-plugins" -version = "1.1.1.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c3/a4/34847b59150da33690a36da3681d6bbc2ec14ee9a846bc30a6746e5984e4/click_plugins-1.1.1.2.tar.gz", hash = "sha256:d7af3984a99d243c131aa1a828331e7630f4a88a9741fd05c927b204bcf92261", size = 8343, upload-time = "2025-06-25T00:47:37.555Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl", hash = "sha256:008d65743833ffc1f5417bf0e78e8d2c23aab04d9745ba817bd3e71b0feb6aa6", size = 11051, upload-time = "2025-06-25T00:47:36.731Z" }, -] - -[[package]] -name = "click-repl" -version = "0.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "prompt-toolkit" }, -] -sdist = 
{ url = "https://files.pythonhosted.org/packages/cb/a2/57f4ac79838cfae6912f997b4d1a64a858fb0c86d7fcaae6f7b58d267fca/click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9", size = 10449, upload-time = "2023-06-15T12:43:51.141Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/52/40/9d857001228658f0d59e97ebd4c346fe73e138c6de1bce61dc568a57c7f8/click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812", size = 10289, upload-time = "2023-06-15T12:43:48.626Z" }, -] - [[package]] name = "codespell" version = "2.4.1" @@ -870,6 +797,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/59/f1/4da7717f0063a222db253e7121bd6a56f6fb1ba439dcc36659088793347c/coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7", size = 203435, upload-time = "2025-03-30T20:36:43.61Z" }, ] +[[package]] +name = "cron-converter" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c1/45/549d071e7bde4d3bb6a566b1a116e3b79803df916c3499d27509b214a965/cron_converter-1.2.1.tar.gz", hash = "sha256:6766c6ba44b8236201ac03030f314fd655343c1c4848ce216458e8d340066c59", size = 14313, upload-time = "2024-05-25T17:56:51.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/76/2a477e17b7c5c49e81bdc711aab7ba9a2a661c54b7c5021e0c1c01abb0e0/cron_converter-1.2.1-py3-none-any.whl", hash = "sha256:4604e356c15a8fbe76a86bb42508f611ad3cade7dd65e2d6f601c2e0d5226ffc", size = 13338, upload-time = "2024-05-25T17:56:49.51Z" }, +] + [[package]] name = "cryptography" version = "45.0.5" @@ -1007,6 +946,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/88/2f3510961fbfabe1441e69783ccca197a129a9705c539aaf019bf6d9fba5/django_cte-2.0.0-py3-none-any.whl", hash = 
"sha256:6de54525be8d9c8bd3d0a2d0422a733295f501bca0ae1c9351cf314d376e0111", size = 12975, upload-time = "2025-06-16T15:29:27.465Z" }, ] +[[package]] +name = "django-dramatiq-postgres" +version = "0.1.0" +source = { editable = "packages/django-dramatiq-postgres" } +dependencies = [ + { name = "cron-converter" }, + { name = "django" }, + { name = "django-pgtrigger" }, + { name = "dramatiq", extra = ["watch"] }, + { name = "tenacity" }, +] + +[package.metadata] +requires-dist = [ + { name = "cron-converter", specifier = ">=1,<2" }, + { name = "django", specifier = ">=4.2,<6.0" }, + { name = "django-pgtrigger", specifier = ">=4,<5" }, + { name = "dramatiq", extras = ["watch"], specifier = ">=1.17,<1.18" }, + { name = "tenacity", specifier = ">=9,<10" }, +] + [[package]] name = "django-filter" version = "25.1" @@ -1068,6 +1028,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/d2/21f19531945f03021460d40654bc2fc3b0c474b57b279d5f5a1c34be7f1b/django_pglock-1.7.2-py3-none-any.whl", hash = "sha256:2f9335527779445fe86507b37e26cfde485a32b91d982a8f80039d3bcd25d596", size = 17674, upload-time = "2025-05-15T22:07:22.618Z" }, ] +[[package]] +name = "django-pgtrigger" +version = "4.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "django" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/5f/dfbe4ac57214549d49e1748fcced105a1ae8718f79bc26c7288835fb6de9/django_pgtrigger-4.15.2.tar.gz", hash = "sha256:a7286db986baa3aa759526c440b8b9baf894b56a4ee237333a44c06f56ab6b9f", size = 33120, upload-time = "2025-04-29T20:38:22.033Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/97/81a4779e73f09d347d006b866b85b4be7774543d9f6eb6fe49a1303802af/django_pgtrigger-4.15.2-py3-none-any.whl", hash = "sha256:77ac6bd44fe5df0307e2630e3294cc68dcb25c9ec1a6bb523667546d28c80bea", size = 36434, upload-time = "2025-04-29T20:38:20.7Z" }, +] + [[package]] name = "django-prometheus" version = "2.4.1" @@ -1176,6 +1148,24 @@ 
wheels = [ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, ] +[[package]] +name = "dramatiq" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prometheus-client" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/7a/6792ddc64a77d22bfd97261b751a7a76cf2f9d62edc59aafb679ac48b77d/dramatiq-1.17.1.tar.gz", hash = "sha256:2675d2f57e0d82db3a7d2a60f1f9c536365349db78c7f8d80a63e4c54697647a", size = 99071, upload-time = "2024-10-26T05:09:28.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/36/925c7afd5db4f1a3f00676b9c3c58f31ff7ae29a347282d86c8d429280a5/dramatiq-1.17.1-py3-none-any.whl", hash = "sha256:951cdc334478dff8e5150bb02a6f7a947d215ee24b5aedaf738eff20e17913df", size = 120382, upload-time = "2024-10-26T05:09:26.436Z" }, +] + +[package.optional-dependencies] +watch = [ + { name = "watchdog" }, + { name = "watchdog-gevent" }, +] + [[package]] name = "drf-jsonschema-serializer" version = "3.0.0" @@ -1279,22 +1269,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/7d/a1dba174d7ec4b6b8d6360eed0ac3a4a4e2aa45f234e903592d3184c6c3f/fido2-2.0.0-py3-none-any.whl", hash = "sha256:685f54a50a57e019c6156e2dd699802a603e3abf70bab334f26affdd4fb8d4f7", size = 224761, upload-time = "2025-05-20T09:44:59.029Z" }, ] -[[package]] -name = "flower" -version = "2.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "celery" }, - { name = "humanize" }, - { name = "prometheus-client" }, - { name = "pytz" }, - { name = "tornado" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/09/a1/357f1b5d8946deafdcfdd604f51baae9de10aafa2908d0b7322597155f92/flower-2.0.1.tar.gz", hash = 
"sha256:5ab717b979530770c16afb48b50d2a98d23c3e9fe39851dcf6bc4d01845a02a0", size = 3220408, upload-time = "2023-08-13T14:37:46.073Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/ff/ee2f67c0ff146ec98b5df1df637b2bc2d17beeb05df9f427a67bd7a7d79c/flower-2.0.1-py2.py3-none-any.whl", hash = "sha256:9db2c621eeefbc844c8dd88be64aef61e84e2deb29b271e02ab2b5b9f01068e2", size = 383553, upload-time = "2023-08-13T14:37:41.552Z" }, -] - [[package]] name = "freezegun" version = "1.5.1" @@ -1385,6 +1359,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/15/cf2a69ade4b194aa524ac75112d5caac37414b20a3a03e6865dfe0bd1539/geopy-2.4.1-py3-none-any.whl", hash = "sha256:ae8b4bc5c1131820f4d75fce9d4aaaca0c85189b3aa5d64c3dcaf5e3b7b882a7", size = 125437, upload-time = "2023-11-23T21:49:30.421Z" }, ] +[[package]] +name = "gevent" +version = "25.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation == 'CPython' and sys_platform == 'win32'" }, + { name = "greenlet", marker = "platform_python_implementation == 'CPython'" }, + { name = "zope-event" }, + { name = "zope-interface" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/58/267e8160aea00ab00acd2de97197eecfe307064a376fb5c892870a8a6159/gevent-25.5.1.tar.gz", hash = "sha256:582c948fa9a23188b890d0bc130734a506d039a2e5ad87dae276a456cc683e61", size = 6388207, upload-time = "2025-05-12T12:57:59.833Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/25/2162b38d7b48e08865db6772d632bd1648136ce2bb50e340565e45607cad/gevent-25.5.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a022a9de9275ce0b390b7315595454258c525dc8287a03f1a6cacc5878ab7cbc", size = 2928044, upload-time = "2025-05-12T11:11:36.33Z" }, + { url = "https://files.pythonhosted.org/packages/1b/e0/dbd597a964ed00176da122ea759bf2a6c1504f1e9f08e185379f92dc355f/gevent-25.5.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash 
= "sha256:3fae8533f9d0ef3348a1f503edcfb531ef7a0236b57da1e24339aceb0ce52922", size = 1788751, upload-time = "2025-05-12T11:52:32.643Z" }, + { url = "https://files.pythonhosted.org/packages/f1/74/960cc4cf4c9c90eafbe0efc238cdf588862e8e278d0b8c0d15a0da4ed480/gevent-25.5.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c7b32d9c3b5294b39ea9060e20c582e49e1ec81edbfeae6cf05f8ad0829cb13d", size = 1869766, upload-time = "2025-05-12T11:54:23.903Z" }, + { url = "https://files.pythonhosted.org/packages/56/78/fa84b1c7db79b156929685db09a7c18c3127361dca18a09e998e98118506/gevent-25.5.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b95815fe44f318ebbfd733b6428b4cb18cc5e68f1c40e8501dd69cc1f42a83d", size = 1835358, upload-time = "2025-05-12T12:00:06.794Z" }, + { url = "https://files.pythonhosted.org/packages/00/5c/bfefe3822bbca5b83bfad256c82251b3f5be13d52d14e17a786847b9b625/gevent-25.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d316529b70d325b183b2f3f5cde958911ff7be12eb2b532b5c301f915dbbf1e", size = 2073071, upload-time = "2025-05-12T11:33:04.2Z" }, + { url = "https://files.pythonhosted.org/packages/20/e4/08a77a3839a37db96393dea952e992d5846a881b887986dde62ead6b48a1/gevent-25.5.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f6ba33c13db91ffdbb489a4f3d177a261ea1843923e1d68a5636c53fe98fa5ce", size = 1809805, upload-time = "2025-05-12T12:00:00.537Z" }, + { url = "https://files.pythonhosted.org/packages/2b/ac/28848348f790c1283df74b0fc0a554271d0606676470f848eccf84eae42a/gevent-25.5.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ee34b77c7553777c0b8379915f75934c3f9c8cd32f7cd098ea43c9323c2276", size = 2138305, upload-time = "2025-05-12T11:40:56.566Z" }, + { url = "https://files.pythonhosted.org/packages/52/9e/0e9e40facd2d714bfb00f71fc6dacaacc82c24c1c2e097bf6461e00dec9f/gevent-25.5.1-cp313-cp313-win_amd64.whl", hash = 
"sha256:9fa6aa0da224ed807d3b76cdb4ee8b54d4d4d5e018aed2478098e685baae7896", size = 1637444, upload-time = "2025-05-12T12:17:45.995Z" }, +] + [[package]] name = "google-api-core" version = "2.25.1" @@ -1456,6 +1452,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[[package]] +name = "greenlet" +version = "3.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/34/c1/a82edae11d46c0d83481aacaa1e578fea21d94a1ef400afd734d47ad95ad/greenlet-3.2.2.tar.gz", hash = "sha256:ad053d34421a2debba45aa3cc39acf454acbcd025b3fc1a9f8a0dee237abd485", size = 185797, upload-time = "2025-05-09T19:47:35.066Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/30/97b49779fff8601af20972a62cc4af0c497c1504dfbb3e93be218e093f21/greenlet-3.2.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:3ab7194ee290302ca15449f601036007873028712e92ca15fc76597a0aeb4c59", size = 269150, upload-time = "2025-05-09T14:50:30.784Z" }, + { url = "https://files.pythonhosted.org/packages/21/30/877245def4220f684bc2e01df1c2e782c164e84b32e07373992f14a2d107/greenlet-3.2.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc5c43bb65ec3669452af0ab10729e8fdc17f87a1f2ad7ec65d4aaaefabf6bf", size = 637381, upload-time = "2025-05-09T15:24:12.893Z" }, + { url = "https://files.pythonhosted.org/packages/8e/16/adf937908e1f913856b5371c1d8bdaef5f58f251d714085abeea73ecc471/greenlet-3.2.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:decb0658ec19e5c1f519faa9a160c0fc85a41a7e6654b3ce1b44b939f8bf1325", size = 651427, upload-time = "2025-05-09T15:24:51.074Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/49/6d79f58fa695b618654adac64e56aff2eeb13344dc28259af8f505662bb1/greenlet-3.2.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fadd183186db360b61cb34e81117a096bff91c072929cd1b529eb20dd46e6c5", size = 645795, upload-time = "2025-05-09T15:29:26.673Z" }, + { url = "https://files.pythonhosted.org/packages/5a/e6/28ed5cb929c6b2f001e96b1d0698c622976cd8f1e41fe7ebc047fa7c6dd4/greenlet-3.2.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1919cbdc1c53ef739c94cf2985056bcc0838c1f217b57647cbf4578576c63825", size = 648398, upload-time = "2025-05-09T14:53:36.61Z" }, + { url = "https://files.pythonhosted.org/packages/9d/70/b200194e25ae86bc57077f695b6cc47ee3118becf54130c5514456cf8dac/greenlet-3.2.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3885f85b61798f4192d544aac7b25a04ece5fe2704670b4ab73c2d2c14ab740d", size = 606795, upload-time = "2025-05-09T14:53:47.039Z" }, + { url = "https://files.pythonhosted.org/packages/f8/c8/ba1def67513a941154ed8f9477ae6e5a03f645be6b507d3930f72ed508d3/greenlet-3.2.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:85f3e248507125bf4af607a26fd6cb8578776197bd4b66e35229cdf5acf1dfbf", size = 1117976, upload-time = "2025-05-09T15:27:06.542Z" }, + { url = "https://files.pythonhosted.org/packages/c3/30/d0e88c1cfcc1b3331d63c2b54a0a3a4a950ef202fb8b92e772ca714a9221/greenlet-3.2.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1e76106b6fc55fa3d6fe1c527f95ee65e324a13b62e243f77b48317346559708", size = 1145509, upload-time = "2025-05-09T14:54:02.223Z" }, + { url = "https://files.pythonhosted.org/packages/90/2e/59d6491834b6e289051b252cf4776d16da51c7c6ca6a87ff97e3a50aa0cd/greenlet-3.2.2-cp313-cp313-win_amd64.whl", hash = "sha256:fe46d4f8e94e637634d54477b0cfabcf93c53f29eedcbdeecaf2af32029b4421", size = 296023, upload-time = "2025-05-09T14:53:24.157Z" }, + { url = 
"https://files.pythonhosted.org/packages/65/66/8a73aace5a5335a1cba56d0da71b7bd93e450f17d372c5b7c5fa547557e9/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba30e88607fb6990544d84caf3c706c4b48f629e18853fc6a646f82db9629418", size = 629911, upload-time = "2025-05-09T15:24:22.376Z" }, + { url = "https://files.pythonhosted.org/packages/48/08/c8b8ebac4e0c95dcc68ec99198842e7db53eda4ab3fb0a4e785690883991/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:055916fafad3e3388d27dd68517478933a97edc2fc54ae79d3bec827de2c64c4", size = 635251, upload-time = "2025-05-09T15:24:52.205Z" }, + { url = "https://files.pythonhosted.org/packages/37/26/7db30868f73e86b9125264d2959acabea132b444b88185ba5c462cb8e571/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2593283bf81ca37d27d110956b79e8723f9aa50c4bcdc29d3c0543d4743d2763", size = 632620, upload-time = "2025-05-09T15:29:28.051Z" }, + { url = "https://files.pythonhosted.org/packages/10/ec/718a3bd56249e729016b0b69bee4adea0dfccf6ca43d147ef3b21edbca16/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89c69e9a10670eb7a66b8cef6354c24671ba241f46152dd3eed447f79c29fb5b", size = 628851, upload-time = "2025-05-09T14:53:38.472Z" }, + { url = "https://files.pythonhosted.org/packages/9b/9d/d1c79286a76bc62ccdc1387291464af16a4204ea717f24e77b0acd623b99/greenlet-3.2.2-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02a98600899ca1ca5d3a2590974c9e3ec259503b2d6ba6527605fcd74e08e207", size = 593718, upload-time = "2025-05-09T14:53:48.313Z" }, + { url = "https://files.pythonhosted.org/packages/cd/41/96ba2bf948f67b245784cd294b84e3d17933597dffd3acdb367a210d1949/greenlet-3.2.2-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:b50a8c5c162469c3209e5ec92ee4f95c8231b11db6a04db09bbe338176723bb8", size = 1105752, upload-time = "2025-05-09T15:27:08.217Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/3b/3b97f9d33c1f2eb081759da62bd6162159db260f602f048bc2f36b4c453e/greenlet-3.2.2-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:45f9f4853fb4cc46783085261c9ec4706628f3b57de3e68bae03e8f8b3c0de51", size = 1125170, upload-time = "2025-05-09T14:54:04.082Z" }, +] + [[package]] name = "gssapi" version = "1.9.0" @@ -1574,15 +1594,6 @@ http2 = [ { name = "h2" }, ] -[[package]] -name = "humanize" -version = "4.12.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/22/d1/bbc4d251187a43f69844f7fd8941426549bbe4723e8ff0a7441796b0789f/humanize-4.12.3.tar.gz", hash = "sha256:8430be3a615106fdfceb0b2c1b41c4c98c6b0fc5cc59663a5539b111dd325fb0", size = 80514, upload-time = "2025-04-30T11:51:07.98Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/1e/62a2ec3104394a2975a2629eec89276ede9dbe717092f6966fcf963e1bf0/humanize-4.12.3-py3-none-any.whl", hash = "sha256:2cbf6370af06568fa6d2da77c86edb7886f3160ecd19ee1ffef07979efc597f6", size = 128487, upload-time = "2025-04-30T11:51:06.468Z" }, -] - [[package]] name = "hyperframe" version = "6.1.0" @@ -1773,21 +1784,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/ba/d86d465c589cbed8f872e187e90abbac73eb1453483477771e87e7ee8376/k5test-0.10.4-py2.py3-none-any.whl", hash = "sha256:33de7ff10bf99155fe8ee5d5976798ad1db6237214306dadf5a0ae9d6bb0ad03", size = 11954, upload-time = "2024-03-20T02:48:24.502Z" }, ] -[[package]] -name = "kombu" -version = "5.5.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "amqp" }, - { name = "packaging" }, - { name = "tzdata" }, - { name = "vine" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0f/d3/5ff936d8319ac86b9c409f1501b07c426e6ad41966fedace9ef1b966e23f/kombu-5.5.4.tar.gz", hash = "sha256:886600168275ebeada93b888e831352fe578168342f0d1d5833d88ba0d847363", size = 461992, upload-time = "2025-06-01T10:19:22.281Z" } -wheels = [ - 
{ url = "https://files.pythonhosted.org/packages/ef/70/a07dcf4f62598c8ad579df241af55ced65bed76e42e45d3c368a6d82dbc1/kombu-5.5.4-py3-none-any.whl", hash = "sha256:a12ed0557c238897d8e518f1d1fdf84bd1516c5e305af2dacd85c2015115feb8", size = 210034, upload-time = "2025-06-01T10:19:20.436Z" }, -] - [[package]] name = "kubernetes" version = "33.1.0" @@ -2332,18 +2328,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/ae/ec06af4fe3ee72d16973474f122541746196aaa16cea6f66d18b963c6177/prometheus_client-0.22.1-py3-none-any.whl", hash = "sha256:cca895342e308174341b2cbf99a56bef291fbc0ef7b9e5412a0f26d653ba7094", size = 58694, upload-time = "2025-06-02T14:29:00.068Z" }, ] -[[package]] -name = "prompt-toolkit" -version = "3.0.51" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wcwidth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bb/6e/9d084c929dfe9e3bfe0c6a47e31f78a25c54627d64a66e884a8bf5474f1c/prompt_toolkit-3.0.51.tar.gz", hash = "sha256:931a162e3b27fc90c86f1b48bb1fb2c528c2761475e57c9c06de13311c7b54ed", size = 428940, upload-time = "2025-04-15T09:18:47.731Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/4f/5249960887b1fbe561d9ff265496d170b55a735b76724f10ef19f9e40716/prompt_toolkit-3.0.51-py3-none-any.whl", hash = "sha256:52742911fde84e2d423e2f9a4cf1de7d7ac4e51958f648d9540e0fb8db077b07", size = 387810, upload-time = "2025-04-15T09:18:44.753Z" }, -] - [[package]] name = "propcache" version = "0.3.2" @@ -2738,15 +2722,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/96/e396b927ce6ccaa4b385b9b54440c133406a3ac070ce6a8a0786d8ba0308/python_kadmin_rs-0.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:b271f5849548a7c248309fa30e4ee90465d6838aad11557fd208ba54da10a072", size = 1612724, upload-time = "2025-06-18T13:24:26.359Z" }, ] -[[package]] -name = "pytz" -version = "2025.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, -] - [[package]] name = "pywin32" version = "310" @@ -3122,34 +3097,12 @@ wheels = [ ] [[package]] -name = "tenant-schemas-celery" -version = "3.0.0" +name = "tenacity" +version = "9.1.2" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "celery" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d0/fe/cfe19eb7cc3ad8e39d7df7b7c44414bf665b6ac6660c998eb498f89d16c6/tenant_schemas_celery-3.0.0.tar.gz", hash = "sha256:6be3ae1a5826f262f0f3dd343c6a85a34a1c59b89e04ae37de018f36562fed55", size = 15954, upload-time = "2024-05-19T11:16:41.837Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/d4/2b0cd0fe285e14b36db076e78c93766ff1d529d70408bd1d2a5a84f1d929/tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb", size = 48036, upload-time = "2025-04-02T08:25:09.966Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/2c/376e1e641ad08b374c75d896468a7be2e6906ce3621fd0c9f9dc09ff1963/tenant_schemas_celery-3.0.0-py3-none-any.whl", hash = "sha256:ca0f69e78ef698eb4813468231df5a0ab6a660c08e657b65f5ac92e16887eec8", size = 18108, upload-time = "2024-05-19T11:16:39.92Z" }, -] - -[[package]] -name = "tornado" -version = "6.5.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/51/89/c72771c81d25d53fe33e3dca61c233b665b2780f21820ba6fd2c6793c12b/tornado-6.5.1.tar.gz", hash = "sha256:84ceece391e8eb9b2b95578db65e920d2a61070260594819589609ba9bc6308c", size = 509934, upload-time = "2025-05-22T18:15:38.788Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/89/f4532dee6843c9e0ebc4e28d4be04c67f54f60813e4bf73d595fe7567452/tornado-6.5.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d50065ba7fd11d3bd41bcad0825227cc9a95154bad83239357094c36708001f7", size = 441948, upload-time = "2025-05-22T18:15:20.862Z" }, - { url = "https://files.pythonhosted.org/packages/15/9a/557406b62cffa395d18772e0cdcf03bed2fff03b374677348eef9f6a3792/tornado-6.5.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:9e9ca370f717997cb85606d074b0e5b247282cf5e2e1611568b8821afe0342d6", size = 440112, upload-time = "2025-05-22T18:15:22.591Z" }, - { url = "https://files.pythonhosted.org/packages/55/82/7721b7319013a3cf881f4dffa4f60ceff07b31b394e459984e7a36dc99ec/tornado-6.5.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b77e9dfa7ed69754a54c89d82ef746398be82f749df69c4d3abe75c4d1ff4888", size = 443672, upload-time = "2025-05-22T18:15:24.027Z" }, - { url = "https://files.pythonhosted.org/packages/7d/42/d11c4376e7d101171b94e03cef0cbce43e823ed6567ceda571f54cf6e3ce/tornado-6.5.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b76040ee3bab8bcf7ba9feb136436a3787208717a1fb9f2c16b744fba7331", size = 443019, upload-time = "2025-05-22T18:15:25.735Z" }, - { url = "https://files.pythonhosted.org/packages/7d/f7/0c48ba992d875521ac761e6e04b0a1750f8150ae42ea26df1852d6a98942/tornado-6.5.1-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:308473f4cc5a76227157cdf904de33ac268af770b2c5f05ca6c1161d82fdd95e", size = 443252, upload-time = "2025-05-22T18:15:27.499Z" }, - { url = 
"https://files.pythonhosted.org/packages/89/46/d8d7413d11987e316df4ad42e16023cd62666a3c0dfa1518ffa30b8df06c/tornado-6.5.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:caec6314ce8a81cf69bd89909f4b633b9f523834dc1a352021775d45e51d9401", size = 443930, upload-time = "2025-05-22T18:15:29.299Z" }, - { url = "https://files.pythonhosted.org/packages/78/b2/f8049221c96a06df89bed68260e8ca94beca5ea532ffc63b1175ad31f9cc/tornado-6.5.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:13ce6e3396c24e2808774741331638ee6c2f50b114b97a55c5b442df65fd9692", size = 443351, upload-time = "2025-05-22T18:15:31.038Z" }, - { url = "https://files.pythonhosted.org/packages/76/ff/6a0079e65b326cc222a54720a748e04a4db246870c4da54ece4577bfa702/tornado-6.5.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5cae6145f4cdf5ab24744526cc0f55a17d76f02c98f4cff9daa08ae9a217448a", size = 443328, upload-time = "2025-05-22T18:15:32.426Z" }, - { url = "https://files.pythonhosted.org/packages/49/18/e3f902a1d21f14035b5bc6246a8c0f51e0eef562ace3a2cea403c1fb7021/tornado-6.5.1-cp39-abi3-win32.whl", hash = "sha256:e0a36e1bc684dca10b1aa75a31df8bdfed656831489bc1e6a6ebed05dc1ec365", size = 444396, upload-time = "2025-05-22T18:15:34.205Z" }, - { url = "https://files.pythonhosted.org/packages/7b/09/6526e32bf1049ee7de3bebba81572673b19a2a8541f795d887e92af1a8bc/tornado-6.5.1-cp39-abi3-win_amd64.whl", hash = "sha256:908e7d64567cecd4c2b458075589a775063453aeb1d2a1853eedb806922f568b", size = 444840, upload-time = "2025-05-22T18:15:36.1Z" }, - { url = "https://files.pythonhosted.org/packages/55/a7/535c44c7bea4578e48281d83c615219f3ab19e6abc67625ef637c73987be/tornado-6.5.1-cp39-abi3-win_arm64.whl", hash = "sha256:02420a0eb7bf617257b9935e2b754d1b63897525d8a289c9d65690d580b4dcf7", size = 443596, upload-time = "2025-05-22T18:15:37.433Z" }, + { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = 
"sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, ] [[package]] @@ -3361,15 +3314,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018, upload-time = "2024-10-14T23:38:10.888Z" }, ] -[[package]] -name = "vine" -version = "5.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bd/e4/d07b5f29d283596b9727dd5275ccbceb63c44a1a82aa9e4bfd20426762ac/vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0", size = 48980, upload-time = "2023-11-05T08:46:53.857Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/03/ff/7c0c86c43b3cbb927e0ccc0255cb4057ceba4799cd44ae95174ce8e8b5b2/vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc", size = 9636, upload-time = "2023-11-05T08:46:51.205Z" }, -] - [[package]] name = "watchdog" version = "6.0.0" @@ -3391,6 +3335,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "watchdog-gevent" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gevent" }, + { name = "watchdog" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/69/91cfca7c21c382e3a8aca4251dcd7d4315228d9346381feb2dde36d14061/watchdog_gevent-0.2.1.tar.gz", hash = "sha256:ae6b94d0f8c8ce1c5956cd865f612b61f456cf19801744bba25a349fe8e8c337", size = 4296, upload-time 
= "2024-10-19T05:29:12.987Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/a9/54b88e150b77791958957e2188312477d09fc84820fc03f8b3a7569d10b0/watchdog_gevent-0.2.1-py3-none-any.whl", hash = "sha256:e8114658104a018f626ee54052335407c1438369febc776c4b4c4308ed002350", size = 3462, upload-time = "2024-10-19T05:29:11.421Z" }, +] + [[package]] name = "watchfiles" version = "1.1.0" @@ -3425,15 +3382,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/f5/cf6aa047d4d9e128f4b7cde615236a915673775ef171ff85971d698f3c2c/watchfiles-1.1.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:cc08ef8b90d78bfac66f0def80240b0197008e4852c9f285907377b2947ffdcb", size = 622744, upload-time = "2025-06-15T19:06:05.066Z" }, ] -[[package]] -name = "wcwidth" -version = "0.2.13" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301, upload-time = "2024-01-06T02:10:57.829Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166, upload-time = "2024-01-06T02:10:55.763Z" }, -] - [[package]] name = "webauthn" version = "2.6.0" @@ -3567,6 +3515,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] +[[package]] +name = "zope-event" +version = "5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/46/c2/427f1867bb96555d1d34342f1dd97f8c420966ab564d58d18469a1db8736/zope.event-5.0.tar.gz", hash = "sha256:bac440d8d9891b4068e2b5a2c5e2c9765a9df762944bda6955f96bb9b91e67cd", size = 17350, upload-time = "2023-06-23T06:28:35.709Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/42/f8dbc2b9ad59e927940325a22d6d3931d630c3644dae7e2369ef5d9ba230/zope.event-5.0-py3-none-any.whl", hash = "sha256:2832e95014f4db26c47a13fdaef84cef2f4df37e66b59d8f1f4a8f319a632c26", size = 6824, upload-time = "2023-06-23T06:28:32.652Z" }, +] + [[package]] name = "zope-interface" version = "7.2" diff --git a/web/src/admin/Routes.ts b/web/src/admin/Routes.ts index 76c251d5c6..116a1ba564 100644 --- a/web/src/admin/Routes.ts +++ b/web/src/admin/Routes.ts @@ -18,8 +18,8 @@ export const ROUTES: Route[] = [ return html``; }), new Route(new RegExp("^/administration/system-tasks$"), async () => { - await import("#admin/system-tasks/SystemTaskListPage"); - return html``; + await import("#admin/admin-overview/SystemTasksPage"); + return html``; }), new Route(new RegExp("^/core/providers$"), async () => { await import("#admin/providers/ProviderListPage"); diff --git a/web/src/admin/admin-overview/SystemTasksPage.ts b/web/src/admin/admin-overview/SystemTasksPage.ts new file mode 100644 index 0000000000..f96716c45c --- /dev/null +++ b/web/src/admin/admin-overview/SystemTasksPage.ts @@ -0,0 +1,85 @@ +import "#components/ak-page-header"; +import "#elements/Tabs"; +import "#elements/buttons/ActionButton/index"; +import "#elements/buttons/SpinnerButton/index"; +import "#elements/events/LogViewer"; +import "#elements/tasks/ScheduleList"; +import "#elements/tasks/TaskList"; +import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; + +import { AKElement } from "#elements/Base"; + +import { msg } from "@lit/localize"; +import { CSSResult, html, TemplateResult } from "lit"; +import { customElement } from "lit/decorators.js"; + +import PFBanner from 
"@patternfly/patternfly/components/Banner/banner.css"; +import PFButton from "@patternfly/patternfly/components/Button/button.css"; +import PFCard from "@patternfly/patternfly/components/Card/card.css"; +import PFContent from "@patternfly/patternfly/components/Content/content.css"; +import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css"; +import PFList from "@patternfly/patternfly/components/List/list.css"; +import PFPage from "@patternfly/patternfly/components/Page/page.css"; +import PFGrid from "@patternfly/patternfly/layouts/Grid/grid.css"; +import PFBase from "@patternfly/patternfly/patternfly-base.css"; + +@customElement("ak-system-tasks") +export class SystemTasksPage extends AKElement { + static get styles(): CSSResult[] { + return [ + PFBase, + PFList, + PFBanner, + PFPage, + PFContent, + PFButton, + PFDescriptionList, + PFGrid, + PFCard, + ]; + } + + render(): TemplateResult { + return html` + +
+
+
+ +
+
+
+
+
+
+ +
+
+
+
`; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-system-tasks": SystemTasksPage; + } +} diff --git a/web/src/admin/admin-overview/cards/WorkerStatusCard.ts b/web/src/admin/admin-overview/cards/WorkerStatusCard.ts index 6bcbdb9b25..3730755459 100644 --- a/web/src/admin/admin-overview/cards/WorkerStatusCard.ts +++ b/web/src/admin/admin-overview/cards/WorkerStatusCard.ts @@ -2,7 +2,7 @@ import { DEFAULT_CONFIG } from "#common/api/config"; import { AdminStatus, AdminStatusCard } from "#admin/admin-overview/cards/AdminStatusCard"; -import { AdminApi, Worker } from "@goauthentik/api"; +import { TasksApi, Worker } from "@goauthentik/api"; import { msg } from "@lit/localize"; import { html, TemplateResult } from "lit"; @@ -13,7 +13,7 @@ export class WorkersStatusCard extends AdminStatusCard { icon = "pf-icon pf-icon-server"; getPrimaryValue(): Promise { - return new AdminApi(DEFAULT_CONFIG).adminWorkersList(); + return new TasksApi(DEFAULT_CONFIG).tasksWorkersList(); } renderHeader(): TemplateResult { diff --git a/web/src/admin/admin-overview/charts/SyncStatusChart.ts b/web/src/admin/admin-overview/charts/SyncStatusChart.ts index 19887358d0..c797ea8673 100644 --- a/web/src/admin/admin-overview/charts/SyncStatusChart.ts +++ b/web/src/admin/admin-overview/charts/SyncStatusChart.ts @@ -11,7 +11,7 @@ import { ProvidersApi, SourcesApi, SyncStatus, - SystemTaskStatusEnum, + TaskAggregatedStatusEnum, } from "@goauthentik/api"; import { ChartData, ChartOptions } from "chart.js"; @@ -61,16 +61,22 @@ export class SyncStatusChart extends AKChart { let objectKey = "healthy"; try { const status = await fetchSyncStatus(element); - status.tasks.forEach((task) => { - if (task.status !== SystemTaskStatusEnum.Successful) { - objectKey = "failed"; - } - const now = new Date().getTime(); - const maxDelta = 3600000; // 1 hour - if (!status || now - task.finishTimestamp.getTime() > maxDelta) { - objectKey = "unsynced"; - } - }); + + const now = new Date().getTime(); + 
const maxDelta = 3600000; // 1 hour + + if ( + status.lastSyncStatus === TaskAggregatedStatusEnum.Error || + status.lastSyncStatus === TaskAggregatedStatusEnum.Rejected || + status.lastSyncStatus === TaskAggregatedStatusEnum.Warning + ) { + objectKey = "failed"; + } else if ( + !status.lastSuccessfulSync || + now - status.lastSuccessfulSync.getTime() > maxDelta + ) { + objectKey = "unsynced"; + } } catch { objectKey = "unsynced"; } @@ -136,6 +142,17 @@ export class SyncStatusChart extends AKChart { }, msg("LDAP Source"), ), + await this.fetchStatus( + () => { + return new SourcesApi(DEFAULT_CONFIG).sourcesKerberosList(); + }, + (element) => { + return new SourcesApi(DEFAULT_CONFIG).sourcesKerberosSyncStatusRetrieve({ + slug: element.slug, + }); + }, + msg("Kerberos Source"), + ), ]; this.centerText = statuses.reduce((total, el) => (total += el.total), 0).toString(); return statuses; diff --git a/web/src/admin/blueprints/BlueprintListPage.ts b/web/src/admin/blueprints/BlueprintListPage.ts index 2fb70d6ca8..317df7cac9 100644 --- a/web/src/admin/blueprints/BlueprintListPage.ts +++ b/web/src/admin/blueprints/BlueprintListPage.ts @@ -5,6 +5,7 @@ import "#elements/buttons/ActionButton/index"; import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/DeleteBulkForm"; import "#elements/forms/ModalForm"; +import "#elements/tasks/TaskList"; import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; import { DEFAULT_CONFIG } from "#common/api/config"; @@ -18,6 +19,7 @@ import { BlueprintInstance, BlueprintInstanceStatusEnum, ManagedApi, + ModelEnum, RbacPermissionsAssignedByUsersListModelEnum, } from "@goauthentik/api"; @@ -108,7 +110,8 @@ export class BlueprintListPage extends TablePage { } renderExpanded(item: BlueprintInstance): TemplateResult { - return html` + const [appLabel, modelName] = ModelEnum.AuthentikBlueprintsBlueprintinstance.split("."); + return html`
@@ -122,6 +125,22 @@ export class BlueprintListPage extends TablePage {
+
+
+
+ ${msg("Tasks")} +
+
+
+ +
+
+
+
`; } diff --git a/web/src/admin/events/RuleListPage.ts b/web/src/admin/events/RuleListPage.ts index 14616d138a..f04e177ccc 100644 --- a/web/src/admin/events/RuleListPage.ts +++ b/web/src/admin/events/RuleListPage.ts @@ -5,6 +5,7 @@ import "#components/ak-status-label"; import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/DeleteBulkForm"; import "#elements/forms/ModalForm"; +import "#elements/tasks/TaskList"; import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; import { DEFAULT_CONFIG } from "#common/api/config"; @@ -15,6 +16,7 @@ import { TablePage } from "#elements/table/TablePage"; import { EventsApi, + ModelEnum, NotificationRule, RbacPermissionsAssignedByUsersListModelEnum, } from "@goauthentik/api"; @@ -125,6 +127,7 @@ export class RuleListPage extends TablePage { } renderExpanded(item: NotificationRule): TemplateResult { + const [appLabel, modelName] = ModelEnum.AuthentikEventsNotificationrule.split("."); return html`

@@ -134,6 +137,22 @@ Bindings to groups/users are checked against the user of the event.`, )}

+
+
+
+ ${msg("Tasks")} +
+
+
+ +
+
+
+
`; } diff --git a/web/src/admin/events/TransportListPage.ts b/web/src/admin/events/TransportListPage.ts index 4cb0b5b40a..e76d6f18e9 100644 --- a/web/src/admin/events/TransportListPage.ts +++ b/web/src/admin/events/TransportListPage.ts @@ -4,6 +4,7 @@ import "#elements/buttons/ActionButton/index"; import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/DeleteBulkForm"; import "#elements/forms/ModalForm"; +import "#elements/tasks/TaskList"; import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; import { DEFAULT_CONFIG } from "#common/api/config"; @@ -13,6 +14,7 @@ import { TablePage } from "#elements/table/TablePage"; import { EventsApi, + ModelEnum, NotificationTransport, RbacPermissionsAssignedByUsersListModelEnum, } from "@goauthentik/api"; @@ -38,6 +40,7 @@ export class TransportListPage extends TablePage { checkbox = true; clearOnRefresh = true; + expandable = true; @property() order = "name"; @@ -114,6 +117,30 @@ export class TransportListPage extends TablePage { ]; } + renderExpanded(item: NotificationTransport): TemplateResult { + const [appLabel, modelName] = ModelEnum.AuthentikEventsNotificationtransport.split("."); + return html` +
+
+
+
+ ${msg("Tasks")} +
+
+
+ +
+
+
+
+
+ `; + } + renderObjectCreate(): TemplateResult { return html` diff --git a/web/src/admin/outposts/OutpostListPage.ts b/web/src/admin/outposts/OutpostListPage.ts index 321acfeae2..87d4b10825 100644 --- a/web/src/admin/outposts/OutpostListPage.ts +++ b/web/src/admin/outposts/OutpostListPage.ts @@ -6,6 +6,8 @@ import "#admin/rbac/ObjectPermissionModal"; import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/DeleteBulkForm"; import "#elements/forms/ModalForm"; +import "#elements/tasks/ScheduleList"; +import "#elements/tasks/TaskList"; import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; import { DEFAULT_CONFIG } from "#common/api/config"; @@ -16,6 +18,7 @@ import { PaginatedResponse, TableColumn } from "#elements/table/Table"; import { TablePage } from "#elements/table/TablePage"; import { + ModelEnum, Outpost, OutpostHealth, OutpostsApi, @@ -160,7 +163,8 @@ export class OutpostListPage extends TablePage { } renderExpanded(item: Outpost): TemplateResult { - return html` + const [appLabel, modelName] = ModelEnum.AuthentikOutpostsOutpost.split("."); + return html`

${msg( @@ -178,6 +182,38 @@ export class OutpostListPage extends TablePage {

`; })} +
+
+
+ ${msg("Schedules")} +
+
+
+ +
+
+
+
+
+
+
+ ${msg("Tasks")} +
+
+
+ +
+
+
+
`; } diff --git a/web/src/admin/outposts/ServiceConnectionListPage.ts b/web/src/admin/outposts/ServiceConnectionListPage.ts index 41bbb9096b..31e5d88b09 100644 --- a/web/src/admin/outposts/ServiceConnectionListPage.ts +++ b/web/src/admin/outposts/ServiceConnectionListPage.ts @@ -8,6 +8,8 @@ import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/DeleteBulkForm"; import "#elements/forms/ModalForm"; import "#elements/forms/ProxyForm"; +import "#elements/tasks/ScheduleList"; +import "#elements/tasks/TaskList"; import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; import { DEFAULT_CONFIG } from "#common/api/config"; @@ -41,6 +43,7 @@ export class OutpostServiceConnectionListPage extends TablePage> { @@ -110,6 +113,46 @@ export class OutpostServiceConnectionListPage extends TablePage +
+
+
+
+ ${msg("Schedules")} +
+
+
+ +
+
+
+
+
+
+
+ ${msg("Tasks")} +
+
+
+ +
+
+
+
+
+ `; + } + renderToolbarSelected(): TemplateResult { const disabled = this.selectedElements.length < 1; return html` -
{ - new ProvidersApi(DEFAULT_CONFIG) - .providersGoogleWorkspaceSyncStatusRetrieve({ - id: this.provider?.pk || 0, - }) - .then((state) => { - this.syncState = state; - }) - .catch(() => { - this.syncState = undefined; - }); - }} - > +
${this.renderTabOverview()}
${msg( @@ -163,7 +149,9 @@ export class GoogleWorkspaceProviderViewPage extends AKElement { ` : html``}
-
+
@@ -210,7 +198,9 @@ export class GoogleWorkspaceProviderViewPage extends AKElement {
-
+
{ return new ProvidersApi( @@ -219,16 +209,36 @@ export class GoogleWorkspaceProviderViewPage extends AKElement { id: this.provider?.pk || 0, }); }} - .triggerSync=${() => { - return new ProvidersApi( - DEFAULT_CONFIG, - ).providersGoogleWorkspacePartialUpdate({ - id: this.provider?.pk || 0, - patchedGoogleWorkspaceProviderRequest: {}, - }); - }} >
+
+
+
+
${msg("Schedules")}
+
+
+ +
+
+
+
+
+
+
${msg("Tasks")}
+
+
+ +
+
+
`; } } diff --git a/web/src/admin/providers/microsoft_entra/MicrosoftEntraProviderViewPage.ts b/web/src/admin/providers/microsoft_entra/MicrosoftEntraProviderViewPage.ts index 6b996a7e81..80ef71fef6 100644 --- a/web/src/admin/providers/microsoft_entra/MicrosoftEntraProviderViewPage.ts +++ b/web/src/admin/providers/microsoft_entra/MicrosoftEntraProviderViewPage.ts @@ -7,6 +7,9 @@ import "#elements/Tabs"; import "#elements/buttons/ActionButton/index"; import "#elements/buttons/ModalButton"; import "#elements/events/LogViewer"; +import "#elements/sync/SyncStatusCard"; +import "#elements/tasks/ScheduleList"; +import "#elements/tasks/TaskList"; import { DEFAULT_CONFIG } from "#common/api/config"; import { EVENT_REFRESH } from "#common/constants"; @@ -15,9 +18,9 @@ import { AKElement } from "#elements/Base"; import { MicrosoftEntraProvider, + ModelEnum, ProvidersApi, RbacPermissionsAssignedByUsersListModelEnum, - SyncStatus, } from "@goauthentik/api"; import { msg } from "@lit/localize"; @@ -44,9 +47,6 @@ export class MicrosoftEntraProviderViewPage extends AKElement { @state() provider?: MicrosoftEntraProvider; - @state() - syncState?: SyncStatus; - static styles: CSSResult[] = [ PFBase, PFButton, @@ -86,22 +86,7 @@ export class MicrosoftEntraProviderViewPage extends AKElement { return html``; } return html` -
{ - new ProvidersApi(DEFAULT_CONFIG) - .providersMicrosoftEntraSyncStatusRetrieve({ - id: this.provider?.pk || 0, - }) - .then((state) => { - this.syncState = state; - }) - .catch(() => { - this.syncState = undefined; - }); - }} - > +
${this.renderTabOverview()}
${msg( @@ -162,7 +149,9 @@ export class MicrosoftEntraProviderViewPage extends AKElement {
` : html``}
-
+
@@ -209,8 +198,9 @@ export class MicrosoftEntraProviderViewPage extends AKElement {
- -
+
{ return new ProvidersApi( @@ -219,16 +209,37 @@ export class MicrosoftEntraProviderViewPage extends AKElement { id: this.provider?.pk || 0, }); }} - .triggerSync=${() => { - return new ProvidersApi( - DEFAULT_CONFIG, - ).providersMicrosoftEntraPartialUpdate({ - id: this.provider?.pk || 0, - patchedMicrosoftEntraProviderRequest: {}, - }); - }} >
+ +
+
+
+
${msg("Schedules")}
+
+
+ +
+
+
+
+
+
+
${msg("Tasks")}
+
+
+ +
+
+
`; } } diff --git a/web/src/admin/providers/scim/SCIMProviderViewPage.ts b/web/src/admin/providers/scim/SCIMProviderViewPage.ts index f764b508f4..bef089a13a 100644 --- a/web/src/admin/providers/scim/SCIMProviderViewPage.ts +++ b/web/src/admin/providers/scim/SCIMProviderViewPage.ts @@ -10,6 +10,8 @@ import "#elements/ak-mdx/index"; import "#elements/buttons/ActionButton/index"; import "#elements/buttons/ModalButton"; import "#elements/sync/SyncStatusCard"; +import "#elements/tasks/ScheduleList"; +import "#elements/tasks/TaskList"; import { DEFAULT_CONFIG } from "#common/api/config"; import { EVENT_REFRESH } from "#common/constants"; @@ -17,6 +19,7 @@ import { EVENT_REFRESH } from "#common/constants"; import { AKElement } from "#elements/Base"; import { + ModelEnum, ProvidersApi, RbacPermissionsAssignedByUsersListModelEnum, SCIMProvider, @@ -142,6 +145,7 @@ export class SCIMProviderViewPage extends AKElement { if (!this.provider) { return html``; } + const [appLabel, modelName] = ModelEnum.AuthentikProvidersScimScimprovider.split("."); return html` ${!this.provider?.assignedBackchannelApplicationName ? html`
${msg( @@ -150,99 +154,121 @@ export class SCIMProviderViewPage extends AKElement {
` : html``}
-
-
-
-
-
-
- ${msg("Name")} -
-
-
- ${this.provider.name} -
-
-
-
-
- ${msg("Assigned to application")} -
-
-
- -
-
-
-
-
- ${msg("Dry-run")} -
-
-
- -
-
-
-
-
- ${msg("URL")} -
-
-
- ${this.provider.url} -
-
-
-
-
- +
+
+
+
+
+ ${msg("Name")} +
+
+
+ ${this.provider.name} +
+
+
+
+
+ ${msg("Assigned to application")} +
+
+
+ +
+
+
+
+
+ ${msg("Dry-run")} +
+
+
+ +
+
+
+
+
+ ${msg("URL")} +
+
+
+ ${this.provider.url} +
+
+
+
-
- { - return new ProvidersApi( - DEFAULT_CONFIG, - ).providersScimSyncStatusRetrieve({ - id: this.provider?.pk || 0, - }); - }} - .triggerSync=${() => { - return new ProvidersApi(DEFAULT_CONFIG).providersScimPartialUpdate({ - id: this.provider?.pk || 0, - patchedSCIMProviderRequest: {}, - }); - }} - > +
-
+
+ { + return new ProvidersApi(DEFAULT_CONFIG).providersScimSyncStatusRetrieve( + { + id: this.provider?.pk || 0, + }, + ); + }} + > +
+
+
+
+
${msg("Schedules")}
+
+
+ +
+
+
+
+
+
+
${msg("Tasks")}
+
+
+ +
+
+
+
diff --git a/web/src/admin/providers/ssf/SSFProviderViewPage.ts b/web/src/admin/providers/ssf/SSFProviderViewPage.ts index c5e2d5405c..1f61bb9173 100644 --- a/web/src/admin/providers/ssf/SSFProviderViewPage.ts +++ b/web/src/admin/providers/ssf/SSFProviderViewPage.ts @@ -7,6 +7,7 @@ import "#elements/EmptyState"; import "#elements/Tabs"; import "#elements/buttons/ModalButton"; import "#elements/buttons/SpinnerButton/index"; +import "#elements/tasks/TaskList"; import { DEFAULT_CONFIG } from "#common/api/config"; import { EVENT_REFRESH } from "#common/constants"; @@ -14,6 +15,7 @@ import { EVENT_REFRESH } from "#common/constants"; import { AKElement } from "#elements/Base"; import { + ModelEnum, ProvidersApi, RbacPermissionsAssignedByUsersListModelEnum, SSFProvider, @@ -109,6 +111,7 @@ export class SSFProviderViewPage extends AKElement { if (!this.provider) { return html``; } + const [appLabel, modelName] = ModelEnum.AuthentikProvidersSsfSsfprovider.split("."); return html`
${msg("SSF Provider is in preview.")} ${msg("Send us feedback!")} @@ -166,6 +169,14 @@ export class SSFProviderViewPage extends AKElement {
+
+
${msg("Tasks")}
+ +
`; } } diff --git a/web/src/admin/sources/kerberos/KerberosSourceViewPage.ts b/web/src/admin/sources/kerberos/KerberosSourceViewPage.ts index 31598331fb..f6722a29f3 100644 --- a/web/src/admin/sources/kerberos/KerberosSourceViewPage.ts +++ b/web/src/admin/sources/kerberos/KerberosSourceViewPage.ts @@ -17,16 +17,16 @@ import { AKElement } from "#elements/Base"; import { KerberosSource, + ModelEnum, RbacPermissionsAssignedByUsersListModelEnum, SourcesApi, - SyncStatus, } from "@goauthentik/api"; import MDSourceKerberosBrowser from "~docs/users-sources/sources/protocols/kerberos/browser.md"; import { msg } from "@lit/localize"; import { CSSResult, html, TemplateResult } from "lit"; -import { customElement, property, state } from "lit/decorators.js"; +import { customElement, property } from "lit/decorators.js"; import PFBanner from "@patternfly/patternfly/components/Banner/banner.css"; import PFButton from "@patternfly/patternfly/components/Button/button.css"; @@ -54,9 +54,6 @@ export class KerberosSourceViewPage extends AKElement { @property({ attribute: false }) source!: KerberosSource; - @state() - syncState?: SyncStatus; - static styles: CSSResult[] = [ PFBase, PFPage, @@ -77,61 +74,16 @@ export class KerberosSourceViewPage extends AKElement { }); } - load(): void { - new SourcesApi(DEFAULT_CONFIG) - .sourcesKerberosSyncStatusRetrieve({ - slug: this.source.slug, - }) - .then((state) => { - this.syncState = state; - }); - } - - renderSyncCards(): TemplateResult { - if (!this.source.syncUsers) { - return html``; - } - return html` -
-
-

${msg("Connectivity")}

-
-
- -
-
-
- { - return new SourcesApi(DEFAULT_CONFIG).sourcesKerberosSyncStatusRetrieve({ - slug: this.source?.slug, - }); - }} - .triggerSync=${() => { - return new SourcesApi(DEFAULT_CONFIG).sourcesKerberosPartialUpdate({ - slug: this.source?.slug || "", - patchedKerberosSourceRequest: {}, - }); - }} - > -
- `; - } - render(): TemplateResult { if (!this.source) { return html``; } + const [appLabel, modelName] = ModelEnum.AuthentikSourcesKerberosKerberossource.split("."); return html`
{ - this.load(); - }} >
${msg("Kerberos Source is in preview.")} @@ -140,7 +92,9 @@ export class KerberosSourceViewPage extends AKElement { >
-
+
@@ -184,7 +138,41 @@ export class KerberosSourceViewPage extends AKElement {
- ${this.renderSyncCards()} +
+ { + return new SourcesApi( + DEFAULT_CONFIG, + ).sourcesKerberosSyncStatusRetrieve({ + slug: this.source?.slug, + }); + }} + > +
+
+
+

${msg("Connectivity")}

+
+
+ +
+
+
+
+

${msg("Schedules")}

+
+
+ +
+
diff --git a/web/src/admin/sources/ldap/LDAPSourceViewPage.ts b/web/src/admin/sources/ldap/LDAPSourceViewPage.ts index 378d5eb0d0..b6d8825d29 100644 --- a/web/src/admin/sources/ldap/LDAPSourceViewPage.ts +++ b/web/src/admin/sources/ldap/LDAPSourceViewPage.ts @@ -8,6 +8,7 @@ import "#elements/buttons/ActionButton/index"; import "#elements/buttons/SpinnerButton/index"; import "#elements/forms/ModalForm"; import "#elements/sync/SyncStatusCard"; +import "#elements/tasks/ScheduleList"; import { DEFAULT_CONFIG } from "#common/api/config"; import { EVENT_REFRESH } from "#common/constants"; @@ -16,14 +17,14 @@ import { AKElement } from "#elements/Base"; import { LDAPSource, + ModelEnum, RbacPermissionsAssignedByUsersListModelEnum, SourcesApi, - SyncStatus, } from "@goauthentik/api"; import { msg } from "@lit/localize"; import { CSSResult, html, TemplateResult } from "lit"; -import { customElement, property, state } from "lit/decorators.js"; +import { customElement, property } from "lit/decorators.js"; import PFButton from "@patternfly/patternfly/components/Button/button.css"; import PFCard from "@patternfly/patternfly/components/Card/card.css"; @@ -50,9 +51,6 @@ export class LDAPSourceViewPage extends AKElement { @property({ attribute: false }) source!: LDAPSource; - @state() - syncState?: SyncStatus; - static styles: CSSResult[] = [ PFBase, PFPage, @@ -72,31 +70,21 @@ export class LDAPSourceViewPage extends AKElement { }); } - load(): void { - new SourcesApi(DEFAULT_CONFIG) - .sourcesLdapSyncStatusRetrieve({ - slug: this.source.slug, - }) - .then((state) => { - this.syncState = state; - }); - } - render(): TemplateResult { if (!this.source) { return html``; } + const [appLabel, modelName] = ModelEnum.AuthentikSourcesLdapLdapsource.split("."); return html`
{ - this.load(); - }} >
-
+
@@ -151,7 +139,20 @@ export class LDAPSourceViewPage extends AKElement {
-
+
+ { + return new SourcesApi(DEFAULT_CONFIG).sourcesLdapSyncStatusRetrieve( + { + slug: this.source?.slug, + }, + ); + }} + > +
+

${msg("Connectivity")}

@@ -161,22 +162,17 @@ export class LDAPSourceViewPage extends AKElement { >
-
- { - return new SourcesApi(DEFAULT_CONFIG).sourcesLdapSyncStatusRetrieve( - { - slug: this.source?.slug, - }, - ); - }} - .triggerSync=${() => { - return new SourcesApi(DEFAULT_CONFIG).sourcesLdapPartialUpdate({ - slug: this.source?.slug || "", - patchedLDAPSourceRequest: {}, - }); - }} - > +
+
+

${msg("Schedules")}

+
+
+ +
diff --git a/web/src/admin/system-tasks/SystemTaskListPage.ts b/web/src/admin/system-tasks/SystemTaskListPage.ts deleted file mode 100644 index 6bcdd71783..0000000000 --- a/web/src/admin/system-tasks/SystemTaskListPage.ts +++ /dev/null @@ -1,162 +0,0 @@ -import "#elements/buttons/ActionButton/index"; -import "#elements/buttons/SpinnerButton/index"; -import "#elements/events/LogViewer"; -import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; - -import { DEFAULT_CONFIG } from "#common/api/config"; -import { EVENT_REFRESH } from "#common/constants"; -import { formatElapsedTime } from "#common/temporal"; - -import { PFColor } from "#elements/Label"; -import { PaginatedResponse, TableColumn } from "#elements/table/Table"; -import { TablePage } from "#elements/table/TablePage"; - -import { EventsApi, SystemTask, SystemTaskStatusEnum } from "@goauthentik/api"; - -import { msg, str } from "@lit/localize"; -import { CSSResult, html, TemplateResult } from "lit"; -import { customElement, property } from "lit/decorators.js"; - -import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css"; - -@customElement("ak-system-task-list") -export class SystemTaskListPage extends TablePage { - pageTitle(): string { - return msg("System Tasks"); - } - pageDescription(): string { - return msg("Long-running operations which authentik executes in the background."); - } - pageIcon(): string { - return "pf-icon pf-icon-automation"; - } - - expandable = true; - - searchEnabled(): boolean { - return true; - } - - @property() - order = "name"; - - static styles: CSSResult[] = [...super.styles, PFDescriptionList]; - - async apiEndpoint(): Promise> { - return new EventsApi(DEFAULT_CONFIG).eventsSystemTasksList( - await this.defaultEndpointConfig(), - ); - } - - columns(): TableColumn[] { - return [ - new TableColumn(msg("Identifier"), "name"), - new TableColumn(msg("Description")), - new TableColumn(msg("Last run")), - new TableColumn(msg("Status"), 
"status"), - new TableColumn(msg("Actions")), - ]; - } - - taskStatus(task: SystemTask): TemplateResult { - switch (task.status) { - case SystemTaskStatusEnum.Successful: - return html`${msg("Successful")}`; - case SystemTaskStatusEnum.Warning: - return html`${msg("Warning")}`; - case SystemTaskStatusEnum.Error: - return html`${msg("Error")}`; - default: - return html`${msg("Unknown")}`; - } - } - - renderExpanded(item: SystemTask): TemplateResult { - return html` -
-
-
-
- ${msg("Duration")} -
-
-
- ${msg(str`${item.duration.toFixed(2)} seconds`)} -
-
-
-
-
- ${msg("Expiry")} -
-
-
- ${item.expiring - ? html` - - ${formatElapsedTime(item.expires || new Date())} - - ` - : msg("-")} -
-
-
-
-
- ${msg("Messages")} -
-
-
- -
-
-
-
-
- - - `; - } - - row(item: SystemTask): TemplateResult[] { - return [ - html`
${item.name}${item.uid ? `:${item.uid}` : ""}
`, - html`${item.description}`, - html`
${formatElapsedTime(item.finishTimestamp)}
- ${item.finishTimestamp.toLocaleString()}`, - this.taskStatus(item), - html` { - return new EventsApi(DEFAULT_CONFIG) - .eventsSystemTasksRunCreate({ - uuid: item.uuid, - }) - .then(() => { - this.dispatchEvent( - new CustomEvent(EVENT_REFRESH, { - bubbles: true, - composed: true, - }), - ); - }); - }} - > - - - - `, - ]; - } -} - -declare global { - interface HTMLElementTagNameMap { - "ak-system-task-list": SystemTaskListPage; - } -} diff --git a/web/src/elements/Label.ts b/web/src/elements/Label.ts index 02972ab98f..0569f8b848 100644 --- a/web/src/elements/Label.ts +++ b/web/src/elements/Label.ts @@ -14,6 +14,7 @@ export enum PFColor { Green = "pf-m-green", Orange = "pf-m-orange", Red = "pf-m-red", + Blue = "pf-m-blue", Grey = "", } @@ -25,6 +26,7 @@ const chromeList: Chrome[] = [ ["danger", PFColor.Red, "pf-m-red", "fa-times"], ["warning", PFColor.Orange, "pf-m-orange", "fa-exclamation-triangle"], ["success", PFColor.Green, "pf-m-green", "fa-check"], + ["running", PFColor.Blue, "pf-m-blue", "fa-clock"], ["info", PFColor.Grey, "pf-m-grey", "fa-info-circle"], ]; diff --git a/web/src/elements/sync/SyncStatusCard.stories.ts b/web/src/elements/sync/SyncStatusCard.stories.ts index d5881cbab7..2d2c5472f6 100644 --- a/web/src/elements/sync/SyncStatusCard.stories.ts +++ b/web/src/elements/sync/SyncStatusCard.stories.ts @@ -1,6 +1,6 @@ import "./SyncStatusCard.js"; -import { LogLevelEnum, SyncStatus, SystemTaskStatusEnum } from "@goauthentik/api"; +import { SyncStatus, TaskAggregatedStatusEnum } from "@goauthentik/api"; import type { Meta, StoryObj } from "@storybook/web-components"; @@ -17,7 +17,6 @@ export const Running: StoryObj = { args: { status: { isRunning: true, - tasks: [], } as SyncStatus, }, // @ts-ignore @@ -32,33 +31,11 @@ export const Running: StoryObj = { }, }; -export const SingleTask: StoryObj = { +export const LastSyncDone: StoryObj = { args: { status: { isRunning: false, - tasks: [ - { - uuid: "9ff42169-8249-4b67-ae3d-e455d822de2b", - name: "Single 
task", - fullName: "foo:bar:baz", - status: SystemTaskStatusEnum.Successful, - messages: [ - { - logger: "foo", - event: "bar", - attributes: { - foo: "bar", - }, - timestamp: new Date(), - logLevel: LogLevelEnum.Info, - }, - ], - description: "foo", - startTimestamp: new Date(), - finishTimestamp: new Date(), - duration: 0, - }, - ], + lastSyncStatus: TaskAggregatedStatusEnum.Done, } as SyncStatus, }, // @ts-ignore @@ -73,75 +50,30 @@ export const SingleTask: StoryObj = { }, }; -export const MultipleTasks: StoryObj = { +export const LastSyncError: StoryObj = { args: { status: { isRunning: false, - tasks: [ - { - uuid: "9ff42169-8249-4b67-ae3d-e455d822de2b", - name: "Single task", - fullName: "foo:bar:baz", - status: SystemTaskStatusEnum.Successful, - messages: [ - { - logger: "foo", - event: "bar", - attributes: { - foo: "bar", - }, - timestamp: new Date(), - logLevel: LogLevelEnum.Info, - }, - ], - description: "foo", - startTimestamp: new Date(), - finishTimestamp: new Date(), - duration: 0, - }, - { - uuid: "9ff42169-8249-4b67-ae3d-e455d822de2b", - name: "Single task", - fullName: "foo:bar:baz", - status: SystemTaskStatusEnum.Successful, - messages: [ - { - logger: "foo", - event: "bar", - attributes: { - foo: "bar", - }, - timestamp: new Date(), - logLevel: LogLevelEnum.Info, - }, - ], - description: "foo", - startTimestamp: new Date(), - finishTimestamp: new Date(), - duration: 0, - }, - { - uuid: "9ff42169-8249-4b67-ae3d-e455d822de2b", - name: "Single task", - fullName: "foo:bar:baz", - status: SystemTaskStatusEnum.Successful, - messages: [ - { - logger: "foo", - event: "bar", - attributes: { - foo: "bar", - }, - timestamp: new Date(), - logLevel: LogLevelEnum.Info, - }, - ], - description: "foo", - startTimestamp: new Date(), - finishTimestamp: new Date(), - duration: 0, - }, - ], + lastSyncStatus: TaskAggregatedStatusEnum.Error, + } as SyncStatus, + }, + // @ts-ignore + render: ({ status }: SyncStatus) => { + return html`
+ { + return status; + }} + > +
`; + }, +}; + +export const LastSuccessfulSync: StoryObj = { + args: { + status: { + isRunning: false, + lastSuccessfulSync: new Date(), } as SyncStatus, }, // @ts-ignore diff --git a/web/src/elements/sync/SyncStatusCard.ts b/web/src/elements/sync/SyncStatusCard.ts index ec4083b001..059d59e596 100644 --- a/web/src/elements/sync/SyncStatusCard.ts +++ b/web/src/elements/sync/SyncStatusCard.ts @@ -2,101 +2,24 @@ import "#components/ak-status-label"; import "#elements/EmptyState"; import "#elements/buttons/ActionButton/index"; import "#elements/events/LogViewer"; +import "#elements/tasks/TaskStatus"; -import { EVENT_REFRESH } from "#common/constants"; import { formatElapsedTime } from "#common/temporal"; import { AKElement } from "#elements/Base"; -import { PaginatedResponse, Table, TableColumn } from "#elements/table/Table"; -import { SlottedTemplateResult } from "#elements/types"; -import { SyncStatus, SystemTask, SystemTaskStatusEnum } from "@goauthentik/api"; +import { SyncStatus } from "@goauthentik/api"; import { msg } from "@lit/localize"; -import { css, CSSResult, html, nothing, TemplateResult } from "lit"; +import { CSSResult, html, TemplateResult } from "lit"; import { customElement, property, state } from "lit/decorators.js"; import PFButton from "@patternfly/patternfly/components/Button/button.css"; import PFCard from "@patternfly/patternfly/components/Card/card.css"; -import PFTable from "@patternfly/patternfly/components/Table/table.css"; +import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css"; +import PFStack from "@patternfly/patternfly/layouts/Stack/stack.css"; import PFBase from "@patternfly/patternfly/patternfly-base.css"; -@customElement("ak-sync-status-table") -export class SyncStatusTable extends Table { - @property({ attribute: false }) - tasks: SystemTask[] = []; - - expandable = true; - - static styles = [ - ...super.styles, - css` - code:not(:last-of-type)::after { - content: "-"; - margin: 0 
0.25rem; - } - `, - ]; - - async apiEndpoint(): Promise> { - if (this.tasks.length === 1) { - this.expandedElements = this.tasks; - } - return { - pagination: { - next: 0, - previous: 0, - count: this.tasks.length, - current: 1, - totalPages: 1, - startIndex: 0, - endIndex: this.tasks.length, - }, - results: this.tasks, - }; - } - - columns(): TableColumn[] { - return [ - new TableColumn(msg("Task")), - new TableColumn(msg("Status")), - new TableColumn(msg("Finished")), - ]; - } - - row(item: SystemTask): TemplateResult[] { - const nameParts = item.fullName.split(":"); - nameParts.shift(); - return [ - html`
${item.name}
- ${nameParts.map((part) => html`${part}`)}`, - html``, - html`
${formatElapsedTime(item.finishTimestamp)}
- ${item.finishTimestamp.toLocaleString()}`, - ]; - } - - renderExpanded(item: SystemTask): TemplateResult { - return html` -
- -
- `; - } - - protected override renderToolbarContainer(): SlottedTemplateResult { - return nothing; - } - - protected override renderTablePagination(): SlottedTemplateResult { - return nothing; - } -} - @customElement("ak-sync-status-card") export class SyncStatusCard extends AKElement { @state() @@ -108,10 +31,9 @@ export class SyncStatusCard extends AKElement { @property({ attribute: false }) fetch!: () => Promise; - @property({ attribute: false }) - triggerSync!: () => Promise; - - static styles: CSSResult[] = [PFBase, PFButton, PFCard, PFTable]; + static get styles(): CSSResult[] { + return [PFBase, PFButton, PFCard, PFDescriptionList, PFStack]; + } firstUpdated() { this.loading = true; @@ -125,16 +47,48 @@ export class SyncStatusCard extends AKElement { if (this.loading) { return html``; } - if (!this.syncState) { - return html`${msg("No sync status.")}`; - } - if (this.syncState.isRunning) { - return html`${msg("Sync currently running.")}`; - } - if (this.syncState.tasks.length < 1) { - return html`${msg("Not synced yet.")}`; - } - return html``; + return html` +
+
+
+ ${msg("Current status")} +
+
+
+ ${this.syncState?.isRunning + ? html`${msg("Sync is currently running.")}` + : html`${msg("Sync is not currently running.")}`} +
+
+
+
+
+ ${msg("Last successful sync")} +
+
+
+ ${this.syncState?.lastSuccessfulSync + ? html`${formatElapsedTime(this.syncState?.lastSuccessfulSync)}` + : html`${msg("No successful sync found.")}`} +
+
+
+
+
+ ${msg("Last sync status")} +
+
+
+ +
+
+
+
+ `; } render(): TemplateResult { @@ -145,7 +99,9 @@ export class SyncStatusCard extends AKElement { class="pf-c-button pf-m-plain" type="button" @click=${() => { - this.fetch(); + this.fetch().then((status) => { + this.syncState = status; + }); }} > @@ -154,42 +110,12 @@ export class SyncStatusCard extends AKElement {
${msg("Sync status")}
${this.renderSyncStatus()}
-
`; } } declare global { interface HTMLElementTagNameMap { - "ak-sync-status-table": SyncStatusTable; "ak-sync-status-card": SyncStatusCard; } } diff --git a/web/src/elements/tasks/ScheduleForm.ts b/web/src/elements/tasks/ScheduleForm.ts new file mode 100644 index 0000000000..954e16fbe2 --- /dev/null +++ b/web/src/elements/tasks/ScheduleForm.ts @@ -0,0 +1,68 @@ +import "#components/ak-switch-input"; +import "#components/ak-text-input"; +import "#elements/forms/FormGroup"; +import "#elements/forms/HorizontalFormElement"; +import "#elements/forms/ModalForm"; +import "#elements/forms/ProxyForm"; + +import { DEFAULT_CONFIG } from "#common/api/config"; + +import { ModelForm } from "#elements/forms/ModelForm"; + +import { Schedule, TasksApi } from "@goauthentik/api"; + +import { msg } from "@lit/localize"; +import { html, TemplateResult } from "lit"; +import { customElement } from "lit/decorators.js"; +import { ifDefined } from "lit/directives/if-defined.js"; + +@customElement("ak-schedule-form") +export class ScheduleForm extends ModelForm { + async loadInstance(pk: string): Promise { + return await new TasksApi(DEFAULT_CONFIG).tasksSchedulesRetrieve({ + id: pk, + }); + } + + getSuccessMessage(): string { + if (!this.instance) { + return ""; + } + return msg("Successfully updated schedule."); + } + + async send(data: Schedule): Promise { + if (!this.instance) { + return; + } + return await new TasksApi(DEFAULT_CONFIG).tasksSchedulesUpdate({ + id: this.instance.id, + scheduleRequest: data, + }); + } + + renderForm(): TemplateResult { + return html`
+ + + +
`; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-schedule-form": ScheduleForm; + } +} diff --git a/web/src/elements/tasks/ScheduleList.ts b/web/src/elements/tasks/ScheduleList.ts new file mode 100644 index 0000000000..fd89bb9bcf --- /dev/null +++ b/web/src/elements/tasks/ScheduleList.ts @@ -0,0 +1,179 @@ +import "#elements/buttons/ActionButton/index"; +import "#elements/buttons/SpinnerButton/index"; +import "#elements/forms/DeleteBulkForm"; +import "#elements/forms/ModalForm"; +import "#elements/tasks/ScheduleForm"; +import "#elements/tasks/TaskList"; +import "#elements/tasks/TaskStatus"; +import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; + +import { DEFAULT_CONFIG } from "#common/api/config"; +import { EVENT_REFRESH } from "#common/constants"; +import { formatElapsedTime } from "#common/temporal"; + +import { PaginatedResponse, Table, TableColumn } from "#elements/table/Table"; + +import { ModelEnum, Schedule, TasksApi } from "@goauthentik/api"; + +import { msg } from "@lit/localize"; +import { CSSResult, html, TemplateResult } from "lit"; +import { customElement, property } from "lit/decorators.js"; + +import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css"; + +@customElement("ak-schedule-list") +export class ScheduleList extends Table { + expandable = true; + clearOnRefresh = true; + + searchEnabled(): boolean { + return true; + } + + @property() + order = "next_run"; + + @property() + relObjAppLabel?: string; + @property() + relObjModel?: string; + @property() + relObjId?: string; + + @property({ type: Boolean }) + showOnlyStandalone: boolean = true; + + static get styles(): CSSResult[] { + return super.styles.concat(PFDescriptionList); + } + + async apiEndpoint(): Promise> { + const relObjIdIsnull = + typeof this.relObjId !== "undefined" + ? undefined + : this.showOnlyStandalone + ? 
true + : undefined; + return new TasksApi(DEFAULT_CONFIG).tasksSchedulesList({ + ...(await this.defaultEndpointConfig()), + relObjContentTypeAppLabel: this.relObjAppLabel, + relObjContentTypeModel: this.relObjModel, + relObjId: this.relObjId, + relObjIdIsnull, + }); + } + + #toggleShowOnlyStandalone = () => { + this.showOnlyStandalone = !this.showOnlyStandalone; + this.page = 1; + return this.fetch(); + }; + + columns(): TableColumn[] { + return [ + new TableColumn(msg("Schedule"), "actor_name"), + new TableColumn(msg("Crontab"), "crontab"), + new TableColumn(msg("Next run"), "next_run"), + new TableColumn(msg("Last status")), + new TableColumn(msg("Actions")), + ]; + } + + renderToolbarAfter(): TemplateResult { + if (this.relObjId !== undefined) { + return html``; + } + return html`  +
+
+
+ +
+
+
`; + } + + row(item: Schedule): TemplateResult[] { + return [ + html`
${item.description}
+ ${item.uid}`, + html`${item.crontab}`, + html` + ${item.paused + ? html`Paused` + : html` +
${formatElapsedTime(item.nextRun)}
+ ${item.nextRun.toLocaleString()} + `} + `, + html``, + html` { + return new TasksApi(DEFAULT_CONFIG) + .tasksSchedulesSendCreate({ + id: item.id, + }) + .then(() => { + this.dispatchEvent( + new CustomEvent(EVENT_REFRESH, { + bubbles: true, + composed: true, + }), + ); + }); + }} + > + + + + + + ${msg("Update")} + ${msg("Update Schedule")} + + + `, + ]; + } + + renderExpanded(item: Schedule): TemplateResult { + const [appLabel, modelName] = ModelEnum.AuthentikTasksSchedulesSchedule.split("."); + return html` +
+
+ +
+
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-schedule-list": ScheduleList; + } +} diff --git a/web/src/elements/tasks/TaskList.ts b/web/src/elements/tasks/TaskList.ts new file mode 100644 index 0000000000..059867eead --- /dev/null +++ b/web/src/elements/tasks/TaskList.ts @@ -0,0 +1,203 @@ +import "#admin/rbac/ObjectPermissionModal"; +import "#elements/buttons/ActionButton/index"; +import "#elements/buttons/SpinnerButton/index"; +import "#elements/events/LogViewer"; +import "#elements/forms/DeleteBulkForm"; +import "#elements/forms/ModalForm"; +import "#elements/tasks/TaskStatus"; +import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; + +import { DEFAULT_CONFIG } from "#common/api/config"; +import { EVENT_REFRESH } from "#common/constants"; +import { formatElapsedTime } from "#common/temporal"; + +import { PaginatedResponse, Table, TableColumn } from "#elements/table/Table"; + +import { + Task, + TasksApi, + TasksTasksListAggregatedStatusEnum, + TasksTasksListStateEnum, +} from "@goauthentik/api"; + +import { msg } from "@lit/localize"; +import { CSSResult, html, TemplateResult } from "lit"; +import { customElement, property } from "lit/decorators.js"; + +import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css"; + +@customElement("ak-task-list") +export class TaskList extends Table { + expandable = true; + clearOnRefresh = true; + + @property() + relObjAppLabel?: string; + @property() + relObjModel?: string; + @property() + relObjId?: string; + + @property({ type: Boolean }) + showOnlyStandalone: boolean = true; + + @property({ type: Boolean }) + excludeSuccessful: boolean = true; + + searchEnabled(): boolean { + return true; + } + + @property() + order = "-mtime"; + + static get styles(): CSSResult[] { + return super.styles.concat(PFDescriptionList); + } + + async apiEndpoint(): Promise> { + const relObjIdIsnull = + typeof this.relObjId !== "undefined" + ? 
undefined + : this.showOnlyStandalone + ? true + : undefined; + const aggregatedStatus = this.excludeSuccessful + ? [ + TasksTasksListAggregatedStatusEnum.Queued, + TasksTasksListAggregatedStatusEnum.Consumed, + TasksTasksListAggregatedStatusEnum.Rejected, + TasksTasksListAggregatedStatusEnum.Warning, + TasksTasksListAggregatedStatusEnum.Error, + ] + : undefined; + return new TasksApi(DEFAULT_CONFIG).tasksTasksList({ + ...(await this.defaultEndpointConfig()), + relObjContentTypeAppLabel: this.relObjAppLabel, + relObjContentTypeModel: this.relObjModel, + relObjId: this.relObjId, + relObjIdIsnull, + aggregatedStatus, + }); + } + + #toggleShowOnlyStandalone = () => { + this.showOnlyStandalone = !this.showOnlyStandalone; + this.page = 1; + return this.fetch(); + }; + + #toggleExcludeSuccessful = () => { + this.excludeSuccessful = !this.excludeSuccessful; + this.page = 1; + return this.fetch(); + }; + + columns(): TableColumn[] { + return [ + new TableColumn(msg("Task"), "actor_name"), + new TableColumn(msg("Queue"), "queue_name"), + new TableColumn(msg("Last updated"), "mtime"), + new TableColumn(msg("Status"), "aggregated_status"), + new TableColumn(msg("Actions")), + ]; + } + + renderToolbarAfter(): TemplateResult { + return html`  +
+
+
+ ${this.relObjId === undefined + ? html` ` + : html``} + +
+
+
`; + } + + row(item: Task): TemplateResult[] { + return [ + html`
${item.description}
+ ${item.uid}`, + html`${item.queueName}`, + html`
${formatElapsedTime(item.mtime || new Date())}
+ ${item.mtime?.toLocaleString()}`, + html``, + item.state === TasksTasksListStateEnum.Rejected || + item.state === TasksTasksListStateEnum.Done + ? html` { + return new TasksApi(DEFAULT_CONFIG) + .tasksTasksRetryCreate({ + messageId: item.messageId ?? "", + }) + .then(() => { + this.dispatchEvent( + new CustomEvent(EVENT_REFRESH, { + bubbles: true, + composed: true, + }), + ); + }); + }} + > + + + + ` + : html``, + ]; + } + + renderExpanded(item: Task): TemplateResult { + return html` +
+
+

Current execution logs

+ +

Previous executions logs

+ +
+
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-task-list": TaskList; + } +} diff --git a/web/src/elements/tasks/TaskStatus.ts b/web/src/elements/tasks/TaskStatus.ts new file mode 100644 index 0000000000..6bbadc4bb2 --- /dev/null +++ b/web/src/elements/tasks/TaskStatus.ts @@ -0,0 +1,66 @@ +import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; + +import { AKElement } from "#elements/Base"; +import { PFColor } from "#elements/Label"; + +import { + LastTaskStatusEnum, + TaskAggregatedStatusEnum, + TasksTasksListAggregatedStatusEnum, +} from "@goauthentik/api"; + +import { msg } from "@lit/localize"; +import { CSSResult, html, TemplateResult } from "lit"; +import { customElement, property } from "lit/decorators.js"; + +import PFButton from "@patternfly/patternfly/components/Button/button.css"; +import PFBase from "@patternfly/patternfly/patternfly-base.css"; + +@customElement("ak-task-status") +export class TaskStatus extends AKElement { + @property() + status?: TaskAggregatedStatusEnum | TasksTasksListAggregatedStatusEnum | LastTaskStatusEnum; + + static get styles(): CSSResult[] { + return [PFBase, PFButton]; + } + + render(): TemplateResult { + switch (this.status) { + case TasksTasksListAggregatedStatusEnum.Queued: + case TaskAggregatedStatusEnum.Queued: + case LastTaskStatusEnum.Queued: + return html`${msg("Waiting to run")}`; + case TasksTasksListAggregatedStatusEnum.Consumed: + case TaskAggregatedStatusEnum.Consumed: + case LastTaskStatusEnum.Consumed: + return html`${msg("Running")}`; + case TasksTasksListAggregatedStatusEnum.Done: + case TaskAggregatedStatusEnum.Done: + case LastTaskStatusEnum.Done: + case TasksTasksListAggregatedStatusEnum.Info: + case TaskAggregatedStatusEnum.Info: + case LastTaskStatusEnum.Info: + return html`${msg("Successful")}`; + case TasksTasksListAggregatedStatusEnum.Warning: + case TaskAggregatedStatusEnum.Warning: + case LastTaskStatusEnum.Warning: + return html`${msg("Warning")}`; + case 
TasksTasksListAggregatedStatusEnum.Rejected: + case TaskAggregatedStatusEnum.Rejected: + case LastTaskStatusEnum.Rejected: + case TasksTasksListAggregatedStatusEnum.Error: + case TaskAggregatedStatusEnum.Error: + case LastTaskStatusEnum.Error: + return html`${msg("Error")}`; + default: + return html`${msg("Unknown")}`; + } + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-task-status": TaskStatus; + } +} diff --git a/website/docs/core/architecture.md b/website/docs/core/architecture.md index cbc4dd0a5a..d2e65ff3ec 100644 --- a/website/docs/core/architecture.md +++ b/website/docs/core/architecture.md @@ -31,7 +31,7 @@ Similar to [other outposts](../add-secure-apps/outposts/index.mdx), this outpost - `/media` is used to store icons and such, but not required, and if not mounted, authentik will allow you to set a URL to icons in place of a file upload -### Background Worker +### Worker This container executes background tasks, such as sending emails, the event notification system, and everything you can see on the _System Tasks_ page in the frontend. diff --git a/website/docs/developer-docs/index.md b/website/docs/developer-docs/index.md index 6cef87a6d7..ba25234a93 100644 --- a/website/docs/developer-docs/index.md +++ b/website/docs/developer-docs/index.md @@ -64,6 +64,7 @@ authentik ├── recovery - Generate keys to use in case you lock yourself out ├── root - Root django application, contains global settings and routes ├── sources +│   ├── kerberos - Sync Kerberos users into authentik │   ├── ldap - Sync LDAP users from OpenLDAP or Active Directory into authentik │   ├── oauth - OAuth1 and OAuth2 Source │   ├── plex - Plex source @@ -87,12 +88,13 @@ authentik │   ├── user_login - Login the currently pending user │   ├── user_logout - Logout the currently pending user │   └── user_write - Write any currently pending data to the user. 
+├── tasks - Background tasks └── tenants - Soft tennancy, configure defaults and branding per domain ``` This Django project is running in gunicorn, which spawns multiple workers and threads. Gunicorn is run from a lightweight Go application which reverse-proxies it, handles static files and will eventually gain more functionality as more code is migrated to go. -There are also several background tasks which run in Celery, the root celery application is defined in `authentik.root.celery`. +There are also several background tasks that run in Dramatiq, via the `django-dramatiq-postgres` package, with some additional helpers in `authentik.tasks`. ## How can I contribute? diff --git a/website/docs/developer-docs/setup/debugging.md b/website/docs/developer-docs/setup/debugging.md index b5846560ac..272d4fd353 100644 --- a/website/docs/developer-docs/setup/debugging.md +++ b/website/docs/developer-docs/setup/debugging.md @@ -6,7 +6,7 @@ This page describes how to debug different components of an authentik instance, ## authentik Server & Worker (Python) -The majority of the authentik codebase is in Python, running in Gunicorn for the server and Celery for the worker. These instructions show how this code can be debugged/inspected. The local debugging setup requires a setup as described in [Full development environment](./full-dev-environment.mdx) +The majority of the authentik codebase is in Python, running in Gunicorn for the server and Dramatiq for the worker. These instructions show how this code can be debugged/inspected. The local debugging setup requires a setup as described in [Full development environment](./full-dev-environment.mdx) Note that authentik uses [debugpy](https://github.com/microsoft/debugpy), which relies on the "Debug Adapter Protocol" (DAP). These instructions demonstrate debugging using [Visual Studio Code](https://code.visualstudio.com/), however they should be adaptable to other editors that support DAP. 
@@ -20,6 +20,10 @@ With this setup in place, you can set Breakpoints in VS Code. To connect to the Note that due to the Python debugger for VS Code, when a Python file in authentik is saved and the Django process restarts, you must manually reconnect the Debug session. Automatic re-connection is not supported for the Python debugger (see [here](https://github.com/microsoft/vscode-python/issues/19998) and [here](https://github.com/microsoft/vscode-python/issues/1182)). ::: +#### Debug the server or the worker + +Whichever process is first started listens on port `9901`. Additional processes started after that will then try to listen on the same port, which will fail, and will simply not start the debugger in that case. + #### Debugging in containers When debugging an authentik instance running in containers, there are some additional steps that need to be taken in addition to the steps above. diff --git a/website/docs/developer-docs/setup/full-dev-environment.mdx b/website/docs/developer-docs/setup/full-dev-environment.mdx index 74b6bdc5d2..9880c16233 100644 --- a/website/docs/developer-docs/setup/full-dev-environment.mdx +++ b/website/docs/developer-docs/setup/full-dev-environment.mdx @@ -177,12 +177,22 @@ make web ## 5. Running authentik -With both backend and frontend set up, start the application: +Now that the backend and frontend have been set up and built, you can start authentik. + +Start the server by running the following command in the same directory as your local authentik git repository: ```shell -make run +make run-server # Starts authentik server ``` +Start the worker by running the following command in the same directory as your local authentik git repository: + +```shell +make run-worker # Starts authentik worker +``` + +Both processes need to run to get a fully functioning authentik development environment. + authentik will be accessible at http://localhost:9000. 
### Initial setup @@ -198,6 +208,10 @@ To define a password for the default admin (called **akadmin**), you can manuall In case of issues in this process, feel free to use `make dev-reset` which drops and restores the authentik PostgreSQL instance to a "fresh install" state. ::: +### Hot-reloading + +When `AUTHENTIK_DEBUG` is set to `true` (the default for the development environment), the authentik server automatically reloads whenever changes are made to the code. However, due to instabilities in the reloading process of the worker, that behavior is turned off for the worker. You can enable code reloading in the worker by manually running `uv run ak worker --watch`. + ## End-to-End (E2E) Setup To run E2E tests, navigate to the `/tests/e2e` directory in your local copy of the authentik git repo, and start the services by running `docker compose up -d`. diff --git a/website/docs/install-config/configuration/configuration.mdx b/website/docs/install-config/configuration/configuration.mdx index a5a7be57a6..ab179fe196 100644 --- a/website/docs/install-config/configuration/configuration.mdx +++ b/website/docs/install-config/configuration/configuration.mdx @@ -208,29 +208,123 @@ When your PostgreSQL databases are running behind a connection pooler (like PgBo - `AUTHENTIK_CHANNEL__URL`: Channel layers configuration URL, uses [the Redis Settings](#redis-settings) by default -## Broker Settings +## Worker settings -- `AUTHENTIK_BROKER__URL`: Broker configuration URL, defaults to Redis using [the respective settings](#redis-settings) -- `AUTHENTIK_BROKER__TRANSPORT_OPTIONS`: Base64-encoded broker transport options +##### `AUTHENTIK_WORKER__PROCESSES` - :::info - `AUTHENTIK_REDIS__CACHE_TIMEOUT_REPUTATION` only applies to the cache expiry, see [`AUTHENTIK_REPUTATION__EXPIRY`](#authentik_reputation__expiry) to control how long reputation is persisted for. - ::: +Configure how many worker processes should be started for Dramatiq to use. 
In environments where scaling with multiple replicas of the authentik worker is not possible, this number can be increased to handle higher loads. + +Defaults to 1. In environments where scaling with multiple replicas of the authentik worker is not possible, this number can be increased to handle higher loads. + +##### `AUTHENTIK_WORKER__THREADS` + +Configure how many Dramatiq threads are started per worker. In environments where scaling with multiple replicas of the authentik worker is not possible, this number can be increased to handle higher loads. + +Defaults to 2. A value below 2 threads is not recommended, unless you have multiple worker replicas. + +##### `AUTHENTIK_WORKER__CONSUMER_LISTEN_TIMEOUT` + +Configure how long a worker waits for a PostgreSQL `LISTEN` notification. + +Defaults to `seconds=30`. + +##### `AUTHENTIK_WORKER__TASK_MAX_RETRIES` + +Configure how many times a failing task will be retried before abandoning. + +Defaults to 20. + +##### `AUTHENTIK_WORKER__TASK_DEFAULT_TIME_LIMIT` + +Configure the default duration a task can run for before it is aborted. Some tasks will override this setting based on other settings, such as LDAP source synchronization tasks. + +Defaults to `minutes=10`. + +##### `AUTHENTIK_WORKER__TASK_PURGE_INTERVAL` + +Configure the interval at which old tasks are cleaned up. + +Defaults to `days=1`. + +##### `AUTHENTIK_WORKER__TASK_EXPIRATION` + +Configure how long tasks are kept in the database before they are deleted. + +Defaults to `days=30`. + +##### `AUTHENTIK_WORKER__SCHEDULER_INTERVAL` + +Configure how often the task scheduler runs. + +Defaults to `seconds=60`. ## Listen Settings -- `AUTHENTIK_LISTEN__HTTP`: Listening address:port (e.g. `0.0.0.0:9000`) for HTTP (Applies to Server and Proxy outpost) -- `AUTHENTIK_LISTEN__HTTPS`: Listening address:port (e.g. `0.0.0.0:9443`) for HTTPS (Applies to Server and Proxy outpost) -- `AUTHENTIK_LISTEN__LDAP`: Listening address:port (e.g. 
`0.0.0.0:3389`) for LDAP (Applies to LDAP outpost) -- `AUTHENTIK_LISTEN__LDAPS`: Listening address:port (e.g. `0.0.0.0:6636`) for LDAPS (Applies to LDAP outpost) -- `AUTHENTIK_LISTEN__METRICS`: Listening address:port (e.g. `0.0.0.0:9300`) for Prometheus metrics (Applies to All) -- `AUTHENTIK_LISTEN__DEBUG`: Listening address:port (e.g. `0.0.0.0:9900`) for Go Debugging metrics (Applies to All) -- `AUTHENTIK_LISTEN__DEBUG_PY`: Listening address:port (e.g. `0.0.0.0:9901`) for Python debugging server (Applies to Server, see [Debugging](../../developer-docs/setup/debugging.md)) -- `AUTHENTIK_LISTEN__TRUSTED_PROXY_CIDRS`: List of comma-separated CIDRs that proxy headers should be accepted from (Applies to Server) +##### `AUTHENTIK_LISTEN__HTTP` - Defaults to `127.0.0.0/8`, `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, `fe80::/10`, `::1/128`. +Listening address:port for HTTP. - Requests directly coming from one an address within a CIDR specified here are able to set proxy headers, such as `X-Forwarded-For`. Requests coming from other addresses will not be able to set these headers. +Applies to the Server, the Worker, and Proxy outposts. + +Defaults to `0.0.0.0:9000`. + +##### `AUTHENTIK_LISTEN__HTTPS` + +Listening address:port for HTTPS. + +Applies to the Server and Proxy outposts. + +Defaults to `0.0.0.0:9443`. + +##### `AUTHENTIK_LISTEN__LDAP` + +Listening address:port for LDAP. + +Applies to LDAP outposts. + +Defaults to `0.0.0.0:3389`. + +##### `AUTHENTIK_LISTEN__LDAPS` + +Listening address:port for LDAPS. + +Applies to LDAP outposts. + +Defaults to `0.0.0.0:6636`. + +##### `AUTHENTIK_LISTEN__METRICS` + +Listening address:port for Prometheus metrics. + +Applies to all. + +Defaults to `0.0.0.0:9300`. + +##### `AUTHENTIK_LISTEN__DEBUG` + +Listening address:port for Go Debugging metrics. + +Applies to all, except the worker. + +Defaults to `0.0.0.0:9900`. 
+ +##### `AUTHENTIK_LISTEN__DEBUG_PY` + +Listening address:port for Python debugging server, see [Debugging](../../developer-docs/setup/debugging.md). + +Applies to the Server and the Worker. + +Defaults to `0.0.0.0:9901`. + +##### `AUTHENTIK_LISTEN__TRUSTED_PROXY_CIDRS` + +List of comma-separated CIDRs that proxy headers should be accepted from. + +Applies to the Server. + +Requests directly coming from one an address within a CIDR specified here are able to set proxy headers, such as `X-Forwarded-For`. Requests coming from other addresses will not be able to set these headers. + +Defaults to `127.0.0.0/8`, `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, `fe80::/10`, `::1/128`. ## Media Storage Settings @@ -426,12 +520,6 @@ Configure how many gunicorn threads a worker processes should have (see https:// Defaults to 4. -### `AUTHENTIK_WORKER__CONCURRENCY` - -Configure Celery worker concurrency for authentik worker (see https://docs.celeryq.dev/en/latest/userguide/configuration.html#worker-concurrency). This essentially defines the number of worker processes spawned for a single worker. - -Defaults to 2. - ### `AUTHENTIK_WEB__PATH` :::info diff --git a/website/docs/releases/2025/v2025.8.md b/website/docs/releases/2025/v2025.8.md new file mode 100644 index 0000000000..1007f134ec --- /dev/null +++ b/website/docs/releases/2025/v2025.8.md @@ -0,0 +1,93 @@ +--- +title: Release 2025.8 +slug: "/releases/2025.8" +--- + +:::::note +2025.8 has not been released yet! We're publishing these release notes as a preview of what's to come, and for our awesome beta testers trying out release candidates. + +To try out the release candidate, replace your Docker image tag with the latest release candidate number, such as 2025.8.0-rc1. You can find the latest one in [the latest releases on GitHub](https://github.com/goauthentik/authentik/releases). If you don't find any, it means we haven't released one yet. 
+::::: + +## Breaking changes + +### Worker rework + +Upgrade instructions: + +- upgrade the authentik server first +- inspect the celery task queue to check that all of them are done: + - `bash -c 'DJANGO_SETTINGS_MODULE=authentik.root.settings celery -A authentik.root.celery inspect active'` + - `bash -c 'DJANGO_SETTINGS_MODULE=authentik.root.settings celery -A authentik.root.celery inspect scheduled'` + - `bash -c 'DJANGO_SETTINGS_MODULE=authentik.root.settings celery -A authentik.root.celery inspect reserved'` +- once no more tasks, upgrade the worker +- ??? +- profit + +### Renamed/removed settings + +The `AUTHENTIK_WORKER__CONCURRENCY` setting has been renamed `AUTHENTIK_WORKER__PROCESSES`. The old setting is still available as an alias and will be removed in a future release. + +The following settings have been removed and no longer have an effect: + +- `AUTHENTIK_BROKER__URL` +- `AUTHENTIK_BROKER__TRANSPORT_OPTIONS` +- `AUTHENTIK_RESULT_BACKEND__URL` + +### Renamed/removed metrics + +The `authentik_admin_workers` metric has been renamed `authentik_tasks_workers`. + +The following metrics have been removed: + +- `authentik_system_tasks` +- `authentik_system_tasks_time_seconds` +- `authentik_system_tasks_status` + +Instead, the following metrics are now available: + +- `authentik_tasks_total` +- `authentik_tasks_errors_total` +- `authentik_tasks_retries_total` +- `authentik_tasks_rejected_total` +- `authentik_tasks_inprogress` +- `authentik_tasks_delayed_inprogress` +- `authentik_tasks_duration_miliseconds` + +## New features + +## Upgrading + +This release does not introduce any new requirements. You can follow the upgrade instructions below; for more detailed information about upgrading authentik, refer to our [Upgrade documentation](../../install-config/upgrade.mdx). + +:::warning +When you upgrade, be aware that the version of the authentik instance and of any outposts must be the same. 
We recommended that you always upgrade any outposts at the same time you upgrade your authentik instance. +::: + +### Docker Compose + +To upgrade, download the new Docker Compose file and update the Docker stack with the new version, using these commands: + +```shell +wget -O docker-compose.yml https://goauthentik.io/version/2025.8/docker-compose.yml +docker compose up -d +``` + +The `-O` flag retains the downloaded file's name, overwriting any existing local file with the same name. + +### Kubernetes + +Upgrade the Helm Chart to the new version, using the following commands: + +```shell +helm repo update +helm upgrade authentik authentik/authentik -f values.yaml --version ^2025.8 +``` + +## Minor changes/fixes + + + +## API Changes + + diff --git a/website/docs/sidebar.mjs b/website/docs/sidebar.mjs index 89f0385721..29c57bdfc4 100644 --- a/website/docs/sidebar.mjs +++ b/website/docs/sidebar.mjs @@ -585,6 +585,7 @@ const items = [ label: "System Management", collapsed: true, items: [ + "sys-mgmt/background-tasks", "sys-mgmt/brands", { //#endregion @@ -595,6 +596,7 @@ const items = [ collapsed: true, items: [ "sys-mgmt/ops/monitoring", + "sys-mgmt/ops/worker", "sys-mgmt/ops/storage-s3", "sys-mgmt/ops/geoip", "sys-mgmt/ops/backup-restore", diff --git a/website/docs/sys-mgmt/background-tasks.md b/website/docs/sys-mgmt/background-tasks.md new file mode 100644 index 0000000000..ca423fbe30 --- /dev/null +++ b/website/docs/sys-mgmt/background-tasks.md @@ -0,0 +1,78 @@ +--- +title: Background tasks +slug: /background-tasks +--- + +authentik uses background tasks to run various operations independently and asynchronously, separated from the continuous web requests processed for general user interaction. These background tasks are run by the [worker](./ops/worker.md). + +## What are background tasks used for? 
+ +Here is a non-exhaustive list of what background tasks are used for: + +- Outposts: manage [outpost](../add-secure-apps/outposts/index.mdx) deployments, send notifications to outpost when a refresh is needed +- Housekeeping: clean up old objects, check for updates, etc. +- Blueprints: import and apply [Blueprints](../customize/blueprints/index.mdx) +- Synchronization: sync users to and from authentik, from sources and to providers. This is used by: + - [SCIM Provider](../add-secure-apps/providers/scim/index.md) + - [Google Workspace Provider](../add-secure-apps/providers/gws/index.md) + - [Microsoft Entra Provider](../add-secure-apps/providers/entra/index.md) + - [SSF Provider](../add-secure-apps/providers/ssf/index.md) + - [Kerberos Source](../users-sources/sources/protocols/kerberos/index.md) + - [LDAP Source](../users-sources/sources/protocols/ldap/index.md) +- Enterprise [license management](../enterprise/manage-enterprise.mdx#license-management) +- Event Notifications: send [Notifications](./events/notifications.md) when events are created +- Emails: send emails when triggered by one of the email stages or otherwise + +## Schedules + +authentik runs some tasks on a schedule. Schedules can be [configured](#schedule-configuration) or manually triggered by clicking the play arrow. + +## Tasks statuses + +A task can have the following statuses: + +- **Successful**: the task executed successfully. No extra action is required. +- **Warning**: the task emitted a warning. Look at the task logs for further information. See [Failed tasks](#failed-tasks) for more details. +- **Error**: the task failed to process. Either the task threw an exception, or reported an other error. Look at the task logs for further information. See [Failed tasks](#failed-tasks) for more details. +- **Waiting to run**: the task has been queued for running, but no worker has picked it up yet, either because none are available, they are already busy, or because it's just been queued. 
+- **Running**: the task is currently running. + +## Manage background tasks + +### View system tasks + +You can view and manage all background tasks and schedules from the Admin interface. + +However, by default, tasks are shown _as close as possible_ to their relevant objects. For instance, LDAP source synchronization tasks and schedules are shown on the LDAP source detail page. + +When a task or a schedule cannot be associated to an object (for example, housekeeping tasks), it is referred to as "standalone" and is displayed under **Dashboards** > **System Tasks**. Note that tasks created from a schedule are associated to that schedule and thus are not considered standalone. Both schedule and task items can be expanded to view additional details about them. + +If you cannot find the object to which a task or schedule is attached, deselect the "Show only standalone tasks/schedules" toggle on the **System Tasks** page to show all tasks and schedules, including the ones that are attached to objects. + +By default, successful tasks are hidden to minimize the number of shown items. Deselect "Exclude successful tasks" to display them. + +### Schedule configuration + +When the authentik system creates a schedule it is assigned a default interval. The schedule uses a format based on [unix-cron](https://man7.org/linux/man-pages/man5/crontab.5.html). + +To change that interval, click the Edit icon for the specific schedule and update it. + +:::warning +Some tasks are required to run at regular intervals. For tuning reasons we recommend editing the intervals only for synchronization schedules, not for other types of schedules. +::: + +Schedules can also be _paused_ to prevent new tasks to be created from them. They can still be triggered manually while paused. When you un-pause a schedule, it will be triggered immediately. + +### Failed tasks + +When a task fails, i.e. 
when the code throws an exception, the task will be retried as many times as the value configured in [`AUTHENTIK_WORKER__TASK_MAX_RETRIES`](../install-config/configuration/configuration.mdx#authentik_worker__task_max_retries). Tasks that self-reported an error or a warning will not be retried. + +Failed tasks will be displayed like any other tasks. Each task can be expanded to show its logs. The logs are split into two parts: "Current execution logs" for the current execution, and "Previous execution logs" for logs from previous executions that happened before a retry was initiated. The information contained in the logs indicate either a transient error (a network connection failed for example), a mis-configuration (wrong password set in the LDAP source for example), or a bug in authentik. + +#### Restarting tasks + +To restart a task, click the retry arrow next to the task. It will be queued again and picked up by a worker. + +:::info +To retry tasks created from a schedule, we recommend manually triggering the schedule (click the Run arrow beside the schedule) instead of restarting one of its tasks. +::: diff --git a/website/docs/sys-mgmt/ops/monitoring.md b/website/docs/sys-mgmt/ops/monitoring.md index ba72008d86..b0351d7fc4 100644 --- a/website/docs/sys-mgmt/ops/monitoring.md +++ b/website/docs/sys-mgmt/ops/monitoring.md @@ -6,11 +6,13 @@ authentik can be easily monitored in multiple ways. ## Server monitoring -Configure your monitoring software to send requests to `/-/health/live/`, which will return a `HTTP 200` response as long as authentik is running. You can also send HTTP requests to `/-/health/ready/`, which will return `HTTP 200` if both PostgreSQL and Redis connections can be/have been established correctly. +Configure your monitoring software to send requests to `/-/health/live/`, which will return a `HTTP 200` response as long as authentik is running. 
You can also send HTTP requests to `/-/health/ready/`, which will return `HTTP 200` if both PostgreSQL and Redis connections can be established correctly. ## Worker monitoring -The worker container can be monitored by running `ak healthcheck` in the worker container. This will ping the worker and ensure it can communicate with redis as required. +The worker container can be monitored by running `ak healthcheck` in the worker container. This will check that the worker is running and ensure that both PostgreSQL and Redis connections can be established correctly. + +You can also send HTTP requests to `/-/health/ready/`, which will return `HTTP 200` if both PostgreSQL and Redis connections can be established correctly. ## Outpost monitoring @@ -22,7 +24,7 @@ Both Docker Compose and Kubernetes deployments use these methods by default to d ## Metrics -Both the core authentik server and any outposts expose Prometheus metrics on a separate port (9300), which can be scraped to gather further insight into authentik's state. The metrics require no authentication, as they are hosted on a separate, non-exposed port by default. +Both the core authentik server, worker and any outposts expose Prometheus metrics on a separate port (9300), which can be scraped to gather further insight into authentik's state. The metrics require no authentication, as they are hosted on a separate, non-exposed port by default. You can find an example dashboard here: [grafana.com](https://grafana.com/grafana/dashboards/14837-authentik/) diff --git a/website/docs/sys-mgmt/ops/worker.md b/website/docs/sys-mgmt/ops/worker.md new file mode 100644 index 0000000000..05560dd101 --- /dev/null +++ b/website/docs/sys-mgmt/ops/worker.md @@ -0,0 +1,50 @@ +--- +title: Worker +slug: /worker +--- + +The authentik worker runs [background tasks](../background-tasks.md). 
The worker also watches for [blueprints](../../customize/blueprints/index.mdx#storage---file) and [certificates](../certificates.md#external-certificates) that are added to the file system. It runs in a separate container from the server to handle these tasks. + +## How it works + +authentik tasks are stored and managed using its PostgreSQL database (installed with authentik). When authentik needs to run a background task, the following happens, inside a PostgreSQL transaction: + +- a row is inserted in a dedicated PostgreSQL table, containing all the relevant information needed to run the task. +- at the end of the transaction, a PostgreSQL trigger executes a `NOTIFY` command to send a notification to workers that a new task is available. + +The worker runs a loop to find tasks that need to run: + +- it tries to `LISTEN` on the tasks channel to pick up new tasks that were just queued. It only does so for a certain period of time (configurable with [`AUTHENTIK_WORKER__CONSUMER_LISTEN_TIMEOUT`](../../install-config/configuration/configuration.mdx#authentik_worker__consumer_listen_timeout)). +- if no task was received, it tries to find tasks that were not registered via a `NOTIFY` command. This happens when no worker is running when a task is created, or if the worker was busy running a different task. This is done by looking into the tasks table for tasks that aren't marked as finished (either successfully or unsuccessfully). On worker start, this is done before `LISTEN`, to process older tasks first. +- if a task is found or received, the worker grabs an advisory lock stating that it is responsible for the task. If several workers try to pick up the same task at the same time, only one of them will grab the task, and the others will continue without a task. +- if a task was found or received and the lock was properly acquired, the task is executed. 
+- if no task was found or the lock couldn't be acquired: + - locks are cleaned up and deleted for tasks that are finished. + - old tasks are purged at a regular interval, configurable with [`AUTHENTIK_WORKER__TASK_PURGE_INTERVAL`](../../install-config/configuration/configuration.mdx#authentik_worker__task_purge_interval). How long tasks are kept for is configurable with [`AUTHENTIK_WORKER__TASK_EXPIRATION`](../../install-config/configuration/configuration.mdx#authentik_worker__task_expiration). + - the scheduler is run at a regular interval, configurable with [`AUTHENTIK_WORKER__SCHEDULER_INTERVAL`](../../install-config/configuration/configuration.mdx#authentik_worker__scheduler_interval). + +### Task retries + +When a task throws an exception, the worker will automatically try to re-run the task up to the value configured by [`AUTHENTIK_WORKER__TASK_MAX_RETRIES`](../../install-config/configuration/configuration.mdx#authentik_worker__task_max_retries). Those retries are done with an exponential backoff strategy; only after all retries are exhausted is the task marked as failed. Otherwise, the task stays in the "Running" status while the worker retries it. However, logs shown in the Admin interface are updated after each try. + +### Time limits + +All tasks have a time limit. If running a task takes longer than than limit, the task is cancelled and marked as failed. The default time limit is configurable with [`AUTHENTIK_WORKER__TASK_DEFAULT_TIME_LIMIT`](../../install-config/configuration/configuration.mdx#authentik_worker__task_default_time_limit). Some tasks override that time limit for specific purposes, like synchronization. + +## Manage the worker + +### Scaling + +How many workers are needed will depend on what tasks are expected to run. 
The number of tasks that can concurrently run is calculated as follows: + +- workers replicas (1 for docker-compose, defaults to 1 for the Helm chart but can be configured) _multiplied_ by [`AUTHENTIK_WORKER__PROCESSES`](../../install-config/configuration/configuration.mdx#authentik_worker__processes) _multiplied_ by [`AUTHENTIK_WORKER__THREADS`](../../install-config/configuration/configuration.mdx#authentik_worker__threads) + +For example, let's say an LDAP source is configured with 1000 users and 200 groups. The LDAP source syncs the users first, then the groups, and finally memberships. All those steps are done by splitting the objects to synchronize into pages, of size [`AUTHENTIK_LDAP__PAGE_SIZE`](../../install-config/configuration/configuration.mdx#authentik_ldap__page_size). Let's say that setting is 50. That means there are `1000 / 50 = 20` pages of users, `200 / 50 = 4` pages of groups. We won't worry about the number of membership pages, because those are usually smaller than the previous ones. + +This means that in this scenario, the maximum number of concurrent tasks will be 20, plus 1 as there is a "meta" task watching over the synchronization and managing the steps so they are executed in order. Thus, for the synchronization to run as fast as possible, there needs to be 21 available workers when it starts. However, other tasks might also be running at the same time, or might get created while the synchronization is running. Thus, we recommend having more workers than necessary to keep a buffer for those tasks. + +### Monitor worker and tasks status + +The workers expose metrics about their operation on [`AUTHENTIK_LISTEN__METRICS`](../../install-config/configuration/configuration.mdx#authentik_listen__metrics). Those metrics allow monitoring of the number of pending, failed and successful tasks. They also provide insights about tasks durations. + +The worker also has an available healthcheck endpoint. 
See [Monitoring](./monitoring.md#worker-monitoring) for details.