Compare commits

..

4 Commits

Author SHA1 Message Date
Marcelo Elizeche Landó
b8f952558f align blueprint import schema with 200 result response 2026-05-06 13:29:35 -03:00
Marc 'risson' Schmitt
ebd18b466d root: ensure uv sync does not update uv.lock (#22084) 2026-05-06 14:48:59 +00:00
dependabot[bot]
b32df17513 core: bump dramatiq from 1.17.1 to 2.1.0 (#22076)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-05-06 14:42:29 +00:00
Teffen Ellis
1db6c3af8b web: Fix Vendored Lex package. Add Unit Tests (#22083)
* Fix API reference paths.

* Clean up vendored code.

* Flesh out test.

* Fix edgecase.

* Clean up return value.

* Fix linter.
2026-05-06 14:31:17 +00:00
40 changed files with 657 additions and 1407 deletions

View File

@@ -49,7 +49,7 @@ runs:
if: ${{ contains(inputs.dependencies, 'python') }}
shell: bash
working-directory: ${{ inputs.working-directory }}
run: uv sync --all-extras --dev --frozen
run: uv sync --all-extras --dev --locked
- name: Setup rust (stable)
if: ${{ contains(inputs.dependencies, 'rust') && !contains(inputs.dependencies, 'rust-nightly') }}
uses: actions-rust-lang/setup-rust-toolchain@2b1f5e9b395427c92ee4e3331786ca3c37afe2d7 # v1

184
Cargo.lock generated
View File

@@ -143,45 +143,6 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"
[[package]]
name = "asn1-rs"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60"
dependencies = [
"asn1-rs-derive",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror 2.0.18",
"time",
]
[[package]]
name = "asn1-rs-derive"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "asn1-rs-impl"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -215,13 +176,10 @@ dependencies = [
"arc-swap",
"argh",
"authentik-axum",
"authentik-client",
"authentik-common",
"axum",
"axum-server",
"color-eyre",
"eyre",
"futures",
"hyper-unix-socket",
"hyper-util",
"metrics",
@@ -229,19 +187,9 @@ dependencies = [
"nix 0.31.2",
"pyo3",
"pyo3-build-config",
"rand 0.10.1",
"rustls",
"serde",
"serde_json",
"serde_repr",
"sqlx",
"time",
"tokio",
"tokio-retry2",
"tokio-tungstenite",
"tower",
"tracing",
"url",
"uuid",
"which",
]
@@ -299,7 +247,6 @@ dependencies = [
"nix 0.31.2",
"notify",
"pin-project-lite",
"rcgen",
"reqwest",
"reqwest-middleware",
"rustls",
@@ -595,17 +542,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chacha20"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
dependencies = [
"cfg-if",
"cpufeatures 0.3.0",
"rand_core 0.10.1",
]
[[package]]
name = "chrono"
version = "0.4.44"
@@ -843,15 +779,6 @@ dependencies = [
"libc",
]
[[package]]
name = "cpufeatures"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
dependencies = [
"libc",
]
[[package]]
name = "crc"
version = "3.4.0"
@@ -946,20 +873,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "der-parser"
version = "10.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6"
dependencies = [
"asn1-rs",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "deranged"
version = "0.5.8"
@@ -1378,7 +1291,6 @@ dependencies = [
"cfg-if",
"libc",
"r-efi 6.0.0",
"rand_core 0.10.1",
"wasip2",
"wasip3",
]
@@ -2270,16 +2182,6 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.6"
@@ -2509,15 +2411,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "oid-registry"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7"
dependencies = [
"asn1-rs",
]
[[package]]
name = "once_cell"
version = "1.21.4"
@@ -2919,17 +2812,6 @@ dependencies = [
"rand_core 0.9.5",
]
[[package]]
name = "rand"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207"
dependencies = [
"chacha20",
"getrandom 0.4.2",
"rand_core 0.10.1",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@@ -2968,12 +2850,6 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "rand_core"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69"
[[package]]
name = "rand_xoshiro"
version = "0.7.0"
@@ -3001,19 +2877,6 @@ dependencies = [
"bitflags 2.11.0",
]
[[package]]
name = "rcgen"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b99e0098aa4082912d4c649628623db6aba77335e4f4569ff5083a6448b32e"
dependencies = [
"aws-lc-rs",
"rustls-pki-types",
"time",
"x509-parser",
"yasna",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -3177,15 +3040,6 @@ dependencies = [
"semver",
]
[[package]]
name = "rusticata-macros"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [
"nom",
]
[[package]]
name = "rustix"
version = "1.1.4"
@@ -3565,7 +3419,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"cpufeatures",
"digest",
]
@@ -3576,7 +3430,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"cpufeatures",
"digest",
]
@@ -4146,12 +4000,8 @@ checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c"
dependencies = [
"futures-util",
"log",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tungstenite",
"webpki-roots 0.26.11",
]
[[package]]
@@ -4365,11 +4215,8 @@ dependencies = [
"httparse",
"log",
"rand 0.9.4",
"rustls",
"rustls-pki-types",
"sha1",
"thiserror 2.0.18",
"url",
]
[[package]]
@@ -5240,24 +5087,6 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "x509-parser"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202"
dependencies = [
"asn1-rs",
"aws-lc-rs",
"data-encoding",
"der-parser",
"lazy_static",
"nom",
"oid-registry",
"rusticata-macros",
"thiserror 2.0.18",
"time",
]
[[package]]
name = "yaml-rust2"
version = "0.10.4"
@@ -5269,15 +5098,6 @@ dependencies = [
"hashlink",
]
[[package]]
name = "yasna"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
dependencies = [
"time",
]
[[package]]
name = "yoke"
version = "0.8.1"

View File

@@ -50,11 +50,6 @@ notify = "= 8.2.0"
pin-project-lite = "= 0.2.17"
pyo3 = "= 0.28.3"
pyo3-build-config = "= 0.28.3"
rand = "= 0.10.1"
rcgen = { version = "= 0.14.7", default-features = false, features = [
"aws_lc_rs",
"fips",
] }
regex = "= 1.12.3"
reqwest = { version = "= 0.13.3", features = [
"form",
@@ -105,10 +100,6 @@ time = { version = "= 0.3.47", features = ["macros"] }
tokio = { version = "= 1.52.1", features = ["full", "tracing"] }
tokio-retry2 = "= 0.9.1"
tokio-rustls = "= 0.26.4"
tokio-tungstenite = { version = "= 0.29.0", features = [
"rustls-tls-webpki-roots",
"url",
] }
tokio-util = { version = "= 0.7.18", features = ["full"] }
tower = "= 0.5.3"
tower-http = { version = "= 0.6.8", features = ["timeout"] }
@@ -269,41 +260,28 @@ publish.workspace = true
[features]
default = ["core", "proxy"]
core = ["ak-common/core", "dep:pyo3", "dep:sqlx"]
proxy = ["ak-common/proxy", "dep:ak-client"]
proxy = ["ak-common/proxy"]
[build-dependencies]
pyo3-build-config.workspace = true
[dependencies]
ak-axum.workspace = true
ak-client = { workspace = true, optional = true }
ak-common.workspace = true
arc-swap.workspace = true
argh.workspace = true
axum-server.workspace = true
axum.workspace = true
color-eyre.workspace = true
eyre.workspace = true
futures.workspace = true
hyper-unix-socket.workspace = true
hyper-util.workspace = true
metrics-exporter-prometheus.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
nix.workspace = true
pyo3 = { workspace = true, optional = true }
rand.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_repr.workspace = true
sqlx = { workspace = true, optional = true }
time.workspace = true
tokio-retry2.workspace = true
tokio-tungstenite.workspace = true
tokio.workspace = true
tower.workspace = true
tracing.workspace = true
url.workspace = true
uuid.workspace = true
which.workspace = true

View File

@@ -217,10 +217,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
@extend_schema(
request={"multipart/form-data": BlueprintUploadSerializer},
responses={
204: BlueprintImportResultSerializer,
400: BlueprintImportResultSerializer,
},
responses={200: BlueprintImportResultSerializer},
)
@action(url_path="import", detail=False, methods=["POST"], parser_classes=(MultiPartParser,))
@validate(
@@ -247,21 +244,13 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
import_response = self.BlueprintImportResultSerializer(
data={
"logs": [],
"success": False,
"logs": [LogEventSerializer(log).data for log in logs],
"success": valid,
}
)
import_response.is_valid(raise_exception=True)
import_response.initial_data["logs"] = [LogEventSerializer(log).data for log in logs]
import_response.initial_data["success"] = valid
import_response.is_valid()
if not valid:
return Response(data=import_response.initial_data, status=200)
successful = importer.apply()
import_response.initial_data["success"] = successful
import_response.is_valid()
if not successful:
return Response(data=import_response.initial_data, status=200)
if valid:
import_response.initial_data["success"] = importer.apply()
import_response.is_valid()
return Response(data=import_response.initial_data, status=200)

View File

@@ -3,6 +3,7 @@
from json import dumps, loads
from tempfile import NamedTemporaryFile, mkdtemp
from django.core.files.uploadedfile import SimpleUploadedFile
from django.urls import reverse
from rest_framework.test import APITestCase
from yaml import dump
@@ -141,6 +142,20 @@ class TestBlueprintsV1API(APITestCase):
)
self.assertEqual(res.status_code, 200)
def test_api_import_invalid_blueprint_returns_result_payload(self):
"""Invalid blueprint content returns a result payload instead of a 400 response."""
file = SimpleUploadedFile("invalid-blueprint.yaml", b'{"version": 3}')
res = self.client.post(
reverse("authentik_api:blueprintinstance-import-"),
data={"file": file},
format="multipart",
)
self.assertEqual(res.status_code, 200)
self.assertFalse(res.json()["success"])
self.assertGreater(len(res.json()["logs"]), 0)
def test_api_import_unknown_path(self):
"""Path not in available blueprints is rejected (covers api.py:56)."""
res = self.client.post(

View File

@@ -7,7 +7,7 @@ from dramatiq.broker import Broker, MessageProxy, get_broker
from dramatiq.middleware.middleware import Middleware
from dramatiq.middleware.retries import Retries
from dramatiq.results.middleware import Results
from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
from dramatiq.worker import ConsumerThread, Worker, WorkerThread
from authentik.tasks.broker import PostgresBroker
@@ -20,7 +20,7 @@ class TestWorker(Worker):
self.worker_id = 1000
self.work_queue = PriorityQueue()
self.consumers = {
TESTING_QUEUE: _ConsumerThread(
TESTING_QUEUE: ConsumerThread(
broker=self.broker,
queue_name=TESTING_QUEUE,
prefetch=2,
@@ -33,7 +33,7 @@ class TestWorker(Worker):
prefetch=2,
timeout=1,
)
self._worker = _WorkerThread(
self._worker = WorkerThread(
broker=self.broker,
consumers=self.consumers,
work_queue=self.work_queue,
@@ -78,17 +78,18 @@ def use_test_broker():
actor.broker = broker
actor.broker.declare_actor(actor)
for middleware_class, middleware_kwargs in Conf().middlewares:
middleware: Middleware = import_string(middleware_class)(
for middleware_class_path, middleware_kwargs in Conf().middlewares:
middleware_class = import_string(middleware_class_path)
if issubclass(middleware_class, Results):
middleware_kwargs["backend"] = import_string(Conf().result_backend)(
*Conf().result_backend_args,
**Conf().result_backend_kwargs,
)
middleware: Middleware = middleware_class(
**middleware_kwargs,
)
if isinstance(middleware, Retries):
middleware.max_retries = 0
if isinstance(middleware, Results):
middleware.backend = import_string(Conf().result_backend)(
*Conf().result_backend_args,
**Conf().result_backend_kwargs,
)
broker.add_middleware(middleware)
broker.start()

View File

@@ -79,7 +79,7 @@ function prepare_debug {
apt-get update
apt-get install -y --no-install-recommends krb5-kdc krb5-user krb5-admin-server libkrb5-dev gcc
source "${VENV_PATH}/bin/activate"
uv sync --active --frozen
uv sync --active --locked
touch /unittest.xml
chown authentik:authentik /unittest.xml
}

View File

@@ -101,6 +101,8 @@ RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
rustc --version && \
cargo --version
RUN cat /root/.rustup/settings.toml
# Stage: Download uv
FROM ghcr.io/astral-sh/uv:0.11.5@sha256:555ac94f9a22e656fc5f2ce5dfee13b04e94d099e46bb8dd3a73ec7263f2e484 AS uv
# Stage: Base python image
@@ -198,7 +200,7 @@ RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
--mount=type=bind,target=packages/django-postgres-cache,src=packages/django-postgres-cache \
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
--mount=type=cache,id=uv-python-deps-$TARGETARCH$TARGETVARIANT,target=/root/.cache/uv \
uv sync --frozen --no-install-project --no-dev
uv sync --locked --no-install-project --no-dev
# Stage: Run
FROM python-base AS final-image

View File

@@ -21,45 +21,33 @@ COPY web .
RUN npm run build-proxy
# Stage 2: Build
FROM ghcr.io/goauthentik/fips-debian:trixie-slim-fips@sha256:7726387c78b5787d2146868c2ccc8948a3591d0a5a6436f7780c8c28acc76341 AS builder
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
ARG TARGETOS
ARG TARGETARCH
ARG TARGETVARIANT
ENV PATH="/root/.cargo/bin:$PATH"
SHELL ["/bin/sh", "-o", "pipefail", "-c"]
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
apt-get update && \
# Required for installing pip packages
apt-get install -y --no-install-recommends \
# Build essentials
build-essential \
# aws-lc deps
cmake clang golang && \
curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal --default-toolchain none && \
rustup install && \
rustup default "$(sed -n 's/channel = "\(.*\)"/\1/p' rust-toolchain.toml)" && \
rustc --version && \
cargo --version
# See https://github.com/aws/aws-lc-rs/issues/569
ENV AWS_LC_FIPS_SYS_CC=clang
ARG GOOS=$TARGETOS
ARG GOARCH=$TARGETARCH
RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
--mount=type=bind,target=Cargo.toml,src=Cargo.toml \
--mount=type=bind,target=Cargo.lock,src=Cargo.lock \
--mount=type=bind,target=.cargo/,src=.cargo/ \
--mount=type=bind,target=src/,src=src/ \
--mount=type=bind,target=packages/,src=packages/ \
--mount=type=bind,target=authentik/lib/default.yml,src=authentik/lib/default.yml \
# Required otherwise workspace discovery fails
--mount=type=bind,target=website/scripts/docsmg/,src=website/scripts/docsmg/ \
--mount=type=cache,id=cargo-git-db-$TARGETARCH$TARGETVARIANT,target=/root/.cargo/git/db/ \
--mount=type=cache,id=cargo-registry-$TARGETARCH$TARGETVARIANT,target=/root/.cargo/registry/ \
--mount=type=cache,id=rust-target-$TARGETARCH$TARGETVARIANT,target=/build/target/ \
cargo build --package authentik --no-default-features --features proxy --locked --release && \
cp ./target/release/authentik /bin/authentik
WORKDIR /go/src/goauthentik.io
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
dpkg --add-architecture arm64 && \
apt-get update && \
apt-get install -y --no-install-recommends crossbuild-essential-arm64 gcc-aarch64-linux-gnu
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
--mount=type=cache,target=/go/pkg/mod \
go mod download
COPY . .
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
if [ "$TARGETARCH" = "arm64" ]; then export CC=aarch64-linux-gnu-gcc && export CC_FOR_TARGET=gcc-aarch64-linux-gnu; fi && \
CGO_ENABLED=1 GOFIPS140=latest GOARM="${TARGETVARIANT#v}" \
go build -o /go/proxy ./cmd/proxy
# Stage 3: Run
FROM ghcr.io/goauthentik/fips-debian:trixie-slim-fips@sha256:7726387c78b5787d2146868c2ccc8948a3591d0a5a6436f7780c8c28acc76341
@@ -84,13 +72,13 @@ RUN apt-get update && \
apt-get clean && \
rm -rf /tmp/* /var/lib/apt/lists/*
COPY --from=builder /bin/authentik /
COPY --from=builder /go/proxy /
COPY --from=web-builder /static/robots.txt /web/robots.txt
COPY --from=web-builder /static/security.txt /web/security.txt
COPY --from=web-builder /static/dist/ /web/dist/
COPY --from=web-builder /static/authentik/ /web/authentik/
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/authentik", "healthcheck" ]
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/proxy", "healthcheck" ]
EXPOSE 9000 9300 9443
@@ -99,4 +87,4 @@ USER 1000
ENV TMPDIR=/dev/shm/ \
GOFIPS=1
ENTRYPOINT ["/authentik", "proxy"]
ENTRYPOINT ["/proxy"]

View File

@@ -28,12 +28,7 @@ class HttpHandler(BaseHTTPRequestHandler):
_ = db_conn.cursor()
def do_GET(self):
from django.db import (
DatabaseError,
InterfaceError,
OperationalError,
connections,
)
from django.db import DatabaseError, InterfaceError, OperationalError, connections
from psycopg.errors import AdminShutdown
from authentik.root.monitoring import monitoring_set
@@ -42,7 +37,6 @@ class HttpHandler(BaseHTTPRequestHandler):
AdminShutdown,
InterfaceError,
DatabaseError,
ConnectionError,
OperationalError,
)

View File

@@ -27,7 +27,6 @@ ipnet.workspace = true
json-subscriber.workspace = true
notify.workspace = true
pin-project-lite.workspace = true
rcgen.workspace = true
reqwest.workspace = true
reqwest-middleware.workspace = true
rustls.workspace = true

View File

@@ -1,6 +1,6 @@
//! Utilities for working with the authentik API client.
use ak_client::{apis::configuration::Configuration, models::Pagination};
use ak_client::apis::configuration::Configuration;
use eyre::{Result, eyre};
use url::Url;
@@ -60,42 +60,6 @@ pub fn make_config() -> Result<Configuration> {
})
}
/// Fetch all pages from a paginated API endpoint, returning all results combined.
///
/// - `fetch`: a function that takes a page number and returns a future resolving to a paginated
/// response.
/// - `get_pagination`: a function that extracts the [`Pagination`] metadata from a response.
/// - `get_results`: a function that extracts the result items from a response.
pub async fn fetch_all<T, R, E, F, Fut>(
fetch: F,
get_pagination: impl Fn(&R) -> &Pagination,
get_results: impl Fn(R) -> Vec<T>,
) -> std::result::Result<Vec<T>, E>
where
F: Fn(i32) -> Fut,
Fut: Future<Output = std::result::Result<R, E>>,
{
let mut page = 1;
let mut results = Vec::with_capacity(0);
loop {
let response = fetch(page).await?;
let next = get_pagination(&response).next;
if page == 1 {
let count = get_pagination(&response).count as usize;
results.reserve(count);
}
results.extend(get_results(response));
if next > 0.0 {
page += 1;
} else {
break;
}
}
Ok(results)
}
#[cfg(test)]
mod tests {
use serde_json::json;

View File

@@ -3,9 +3,8 @@ use std::{collections::HashMap, net::SocketAddr, num::NonZeroUsize};
use ipnet::IpNet;
use serde::{Deserialize, Serialize};
pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 5] = [
pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 4] = [
"listen.http",
"listen.https",
"listen.metrics",
"listen.trusted_proxy_cidrs",
"log.http_headers",
@@ -60,7 +59,6 @@ pub struct PostgreSQLConfig {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListenConfig {
pub http: Vec<SocketAddr>,
pub https: Vec<SocketAddr>,
pub metrics: Vec<SocketAddr>,
pub debug_tokio: SocketAddr,
pub trusted_proxy_cidrs: Vec<IpNet>,

View File

@@ -7,8 +7,6 @@ use tracing::trace;
use crate::config;
pub mod self_signed;
/// Dummy resolver for FIPS compliance check.
#[derive(Debug)]
struct EmptyCertResolver;

View File

@@ -1,52 +0,0 @@
use eyre::Result;
use rcgen::{
Certificate, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, KeyPair,
KeyUsagePurpose, PKCS_RSA_SHA256, SanType,
};
use rustls::{
crypto::aws_lc_rs::sign::any_supported_type,
pki_types::{CertificateDer, PrivateKeyDer},
sign::CertifiedKey,
};
use time::{Duration, OffsetDateTime};
pub fn generate() -> Result<(Certificate, KeyPair)> {
let signing_key = KeyPair::generate_for(&PKCS_RSA_SHA256)?;
let mut params = CertificateParams::default();
params.not_before = OffsetDateTime::now_utc();
params.not_after = OffsetDateTime::now_utc() + Duration::days(365);
params.distinguished_name = {
let mut dn = DistinguishedName::new();
dn.push(DnType::OrganizationName, "authentik");
dn.push(DnType::CommonName, "authentik default certificate");
dn
};
params.subject_alt_names = vec![SanType::DnsName("*".try_into()?)];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
let cert = params.self_signed(&signing_key)?;
Ok((cert, signing_key))
}
pub fn generate_certifiedkey() -> Result<CertifiedKey> {
let (cert, keypair) = generate()?;
let cert_der = cert.der().to_vec();
let key_der = keypair.serialize_der();
let private_key =
PrivateKeyDer::try_from(key_der).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
let signing_key =
any_supported_type(&private_key).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
Ok(CertifiedKey::new(
vec![CertificateDer::from(cert_der)],
signing_key,
))
}

View File

@@ -30,12 +30,12 @@ pub fn install() -> Result<()> {
}
if config.debug {
// let console_layer = console_subscriber::ConsoleLayer::builder()
// .server_addr(config.listen.debug_tokio)
// .spawn();
let console_layer = console_subscriber::ConsoleLayer::builder()
.server_addr(config.listen.debug_tokio)
.spawn();
tracing_subscriber::registry()
.with(ErrorLayer::default())
// .with(console_layer)
.with(console_layer)
.with(
fmt::layer()
.compact()
@@ -187,10 +187,13 @@ pub mod sentry {
environment: config.environment,
send_pii: config.send_pii,
#[expect(
clippy::as_conversions,
clippy::cast_possible_truncation,
reason = "This is fine, we'll never get big values here."
)]
#[expect(
clippy::as_conversions,
reason = "This is fine, we'll never get big values here."
)]
sample_rate: config.traces_sample_rate as f32,
})
}

View File

@@ -8,8 +8,8 @@
"url": "https://github.com/goauthentik/authentik.git"
},
"scripts": {
"clean": "tsc -b --clean tsconfig.json tsconfig.esm.json",
"build": "npm run clean && tsc -b tsconfig.json tsconfig.esm.json",
"clean": "tsc -b --clean tsconfig.json tsconfig.esm.json",
"prepare": "npm run build"
},
"main": "./dist/index.js",

View File

@@ -32,16 +32,17 @@ class DjangoDramatiqPostgres(AppConfig):
middleware=[],
)
for middleware_class, middleware_kwargs in Conf().middlewares:
middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)(
**middleware_kwargs,
)
if isinstance(middleware, Results):
middleware.backend = import_string(Conf().result_backend)(
for middleware_class_path, middleware_kwargs in Conf().middlewares:
middleware_class = import_string(middleware_class_path)
if issubclass(middleware_class, Results):
middleware_kwargs["backend"] = import_string(Conf().result_backend)(
*Conf().result_backend_args,
**Conf().result_backend_kwargs,
)
broker.add_middleware(middleware) # type: ignore[no-untyped-call]
middleware: dramatiq.middleware.middleware.Middleware = middleware_class(
**middleware_kwargs,
)
broker.add_middleware(middleware)
dramatiq.set_broker(broker)

View File

@@ -23,11 +23,9 @@ from django.utils.functional import cached_property
from django.utils.module_loading import import_string
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, q_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.errors import BrokerConnectionError, QueueJoinTimeout
from dramatiq.message import Message
from dramatiq.middleware import (
Middleware,
)
from dramatiq.middleware import Middleware
from pglock.core import _cast_lock_id
from psycopg import sql
from psycopg.errors import AdminShutdown
@@ -46,7 +44,6 @@ DATABASE_ERRORS = (
AdminShutdown,
InterfaceError,
DatabaseError,
ConnectionError,
OperationalError,
)
@@ -55,7 +52,7 @@ def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP047
def raise_broker_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP047
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
@@ -66,13 +63,13 @@ def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP0
connections.close_all()
except DATABASE_ERRORS:
pass
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
raise BrokerConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
return wrapper
class PostgresBroker(Broker):
queues: set[str] # type: ignore[assignment]
queues: set[str]
def __init__(
self,
@@ -81,7 +78,7 @@ class PostgresBroker(Broker):
db_alias: str = DEFAULT_DB_ALIAS,
**kwargs: Any,
) -> None:
super().__init__(*args, middleware=[], **kwargs) # type: ignore[no-untyped-call,misc]
super().__init__(*args, middleware=[], **kwargs) # type: ignore[misc]
self.logger = get_logger(__name__, type(self))
self.queues = set()
@@ -122,10 +119,10 @@ class PostgresBroker(Broker):
def declare_queue(self, queue_name: str) -> None:
if queue_name not in self.queues:
self.emit_before("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_before("declare_queue", queue_name)
self.queues.add(queue_name)
# Nothing more to do, all queues are in the same table
self.emit_after("declare_queue", queue_name) # type: ignore[no-untyped-call]
self.emit_after("declare_queue", queue_name)
def model_defaults(self, message: Message[Any]) -> dict[str, Any]:
eta = None
@@ -141,7 +138,7 @@ class PostgresBroker(Broker):
}
@tenacity.retry(
retry=tenacity.retry_if_exception_type(ConnectionError),
retry=tenacity.retry_if_exception_type(BrokerConnectionError),
reraise=True,
wait=tenacity.wait_random_exponential(multiplier=1, max=5),
stop=tenacity.stop_after_attempt(3),
@@ -149,11 +146,11 @@ class PostgresBroker(Broker):
cast(logging.Logger, logger), logging.INFO, exc_info=True
),
)
@raise_connection_error
@raise_broker_connection_error
def enqueue(self, message: Message[Any], *, delay: int | None = None) -> Message[Any]:
queue_name = q_name(message.queue_name) # type: ignore[no-untyped-call]
queue_name = q_name(message.queue_name)
if delay:
message_eta = current_millis() + delay # type: ignore[no-untyped-call]
message_eta = current_millis() + delay
message.options["eta"] = message_eta
self.declare_queue(queue_name)
@@ -163,7 +160,7 @@ class PostgresBroker(Broker):
message.options["model_defaults"] = self.model_defaults(message)
message.options["model_create_defaults"] = {}
self.emit_before("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_before("enqueue", message, delay)
with transaction.atomic(using=self.db_alias):
query = {
@@ -185,7 +182,7 @@ class PostgresBroker(Broker):
message.options["task"] = task
message.options["task_created"] = created
self.emit_after("enqueue", message, delay) # type: ignore[no-untyped-call]
self.emit_after("enqueue", message, delay)
return message
def get_declared_queues(self) -> set[str]:
@@ -193,7 +190,7 @@ class PostgresBroker(Broker):
def flush(self, queue_name: str) -> None:
self.query_set.filter(
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name)) # type: ignore[no-untyped-call]
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
).delete()
def flush_all(self) -> None:
@@ -375,7 +372,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.add(str(message_id))
return message
@raise_connection_error
@raise_broker_connection_error
def __next__(self) -> MessageProxy | None:
# This method is called every second
@@ -395,7 +392,7 @@ class _PostgresConsumer(Consumer):
if processing >= self.prefetch:
# If we have too many messages already processing, wait and don't consume a message
# straight away, other workers will be faster.
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000) # type: ignore[no-untyped-call]
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
self.logger.debug(
"Too many messages in processing, Sleeping",
processing=processing,
@@ -420,7 +417,7 @@ class _PostgresConsumer(Consumer):
break
message = self._consume_one(str(message_id))
if message is not None:
return MessageProxy(message) # type: ignore[no-untyped-call]
return MessageProxy(message)
else:
self.logger.debug("Message already consumed. Skipping.", message_id=message_id)
continue
@@ -444,7 +441,7 @@ class _PostgresConsumer(Consumer):
self.to_unlock.add(str(message_id))
return False
def _post_process_message(self, message: Message[Any], state: TaskState) -> None:
def _post_process_message(self, message: MessageProxy, state: TaskState) -> None:
self.logger.debug("Post-processing message", message=message.message_id, state=state)
try:
self.in_processing.remove(str(message.message_id))
@@ -466,16 +463,16 @@ class _PostgresConsumer(Consumer):
)
message.options["task"] = task
@raise_connection_error
def ack(self, message: Message[Any]) -> None:
@raise_broker_connection_error
def ack(self, message: MessageProxy) -> None:
self._post_process_message(message, TaskState.DONE)
@raise_connection_error
def nack(self, message: Message[Any]) -> None:
@raise_broker_connection_error
def nack(self, message: MessageProxy) -> None:
self._post_process_message(message, TaskState.REJECTED)
@raise_connection_error
def requeue(self, messages: Iterable[Message[Any]]) -> None:
@raise_broker_connection_error
def requeue(self, messages: Iterable[MessageProxy]) -> None:
self.query_set.filter(
message_id__in=[message.message_id for message in messages],
).update(
@@ -514,7 +511,7 @@ class _PostgresConsumer(Consumer):
self.logger.info("Purged messages in all queues", count=count)
self.task_purge_last_run = timezone.now()
@raise_connection_error
@raise_broker_connection_error
def close(self) -> None:
try:
self._purge_locks()

View File

@@ -5,7 +5,7 @@ from signal import pause
from django_dramatiq_postgres.conf import Conf
def worker_metrics() -> None:
def worker_metrics() -> int:
import_module(Conf().autodiscovery["setup_module"])
from django_dramatiq_postgres.middleware import MetricsMiddleware
@@ -15,3 +15,4 @@ def worker_metrics() -> None:
int(os.getenv("dramatiq_prom_port", "9191")),
)
pause()
return 0

View File

@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, cast
from django.db import DatabaseError, close_old_connections, connections
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.broker import Broker, MessageProxy
from dramatiq.common import current_millis
from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware
@@ -79,7 +79,7 @@ class DbConnectionMiddleware(Middleware):
class TaskStateBeforeMiddleware(Middleware):
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: # type: ignore[override]
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -90,7 +90,7 @@ class TaskStateBeforeMiddleware(Middleware):
class TaskStateAfterMiddleware(Middleware):
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
def before_process_message(self, broker: PostgresBroker, message: MessageProxy) -> None: # type: ignore[override]
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -99,7 +99,7 @@ class TaskStateAfterMiddleware(Middleware):
state=TaskState.RUNNING,
)
def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
def after_skip_message(self, broker: PostgresBroker, message: MessageProxy) -> None: # type: ignore[override]
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -110,11 +110,11 @@ class TaskStateAfterMiddleware(Middleware):
def after_process_message(
self,
broker: PostgresBroker,
message: Message[Any],
broker: PostgresBroker, # type: ignore[override]
message: MessageProxy,
*,
result: Any | None = None,
exception: Exception | None = None,
exception: BaseException | None = None,
) -> None:
self.after_skip_message(broker, message)
@@ -147,7 +147,7 @@ class CurrentTask(Middleware):
raise CurrentTaskNotFound()
return task[-1]
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: MessageProxy) -> None:
tasks = self._TASKS.get()
if tasks is None:
tasks = []
@@ -157,10 +157,10 @@ class CurrentTask(Middleware):
def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: MessageProxy,
*,
result: Any | None = None,
exception: Exception | None = None,
exception: BaseException | None = None,
) -> None:
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
@@ -194,7 +194,7 @@ class CurrentTask(Middleware):
pass
self._TASKS.set(tasks[:-1])
def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
def after_skip_message(self, broker: Broker, message: MessageProxy) -> None:
self.after_process_message(broker, message)
@@ -236,7 +236,7 @@ class MetricsMiddleware(Middleware):
self.message_start_times: dict[str, int] = {}
@property
def forks(self) -> list[Callable[[], None]]:
def forks(self) -> list[Callable[[], int]]:
from django_dramatiq_postgres.forks import worker_metrics
return [worker_metrics]
@@ -310,41 +310,41 @@ class MetricsMiddleware(Middleware):
# TODO: worker_id
multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call]
def _make_labels(self, message: Message[Any]) -> list[str]:
def _make_labels(self, message: MessageProxy | Message[Any]) -> list[str]:
return [message.queue_name, message.actor_name]
def after_nack(self, broker: Broker, message: Message[Any]) -> None:
def after_nack(self, broker: Broker, message: MessageProxy) -> None:
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()
def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
def before_delay_message(self, broker: Broker, message: MessageProxy) -> None:
self.delayed_messages.add(message.message_id)
self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
def before_process_message(self, broker: Broker, message: MessageProxy) -> None:
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.in_progress_delayed_messages.labels(*labels).dec()
self.in_progress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call]
self.message_start_times[message.message_id] = current_millis()
def after_process_message(
self,
broker: Broker,
message: Message[Any],
message: MessageProxy,
*,
result: Any | None = None,
exception: Exception | None = None,
exception: BaseException | None = None,
) -> None:
labels = self._make_labels(message)
message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call]
message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call]
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
message_duration = current_millis() - message_start_time
self.messages_durations.labels(*labels).observe(message_duration)
self.in_progress_messages.labels(*labels).dec()

View File

@@ -159,7 +159,7 @@ class ScheduleBase(models.Model):
def send(self, broker: Broker | None = None) -> Message[Any]:
broker = broker or get_broker()
actor: Actor[Any, Any] = broker.get_actor(self.actor_name) # type: ignore[no-untyped-call]
actor: Actor[Any, Any] = broker.get_actor(self.actor_name)
return actor.send_with_options(
args=pickle.loads(self.args), # nosec
kwargs=pickle.loads(self.kwargs), # nosec

View File

@@ -36,7 +36,7 @@ dependencies = [
"django >=4.2,<6.0",
"django-pglock >=1.7,<2",
"django-pgtrigger >=4,<5",
"dramatiq >=1.17,<1.18",
"dramatiq >=2,<3",
"tenacity >=9,<10",
"structlog >=25,<26",
]

View File

@@ -9678,18 +9678,14 @@ paths:
security:
- authentik: []
responses:
'204':
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/BlueprintImportResult'
description: ''
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/BlueprintImportResult'
description: ''
$ref: '#/components/responses/ValidationErrorResponse'
'403':
$ref: '#/components/responses/GenericErrorResponse'
/oauth2/access_tokens/:

View File

@@ -8,8 +8,6 @@ use eyre::{Result, eyre};
use tracing::{error, info, trace};
mod metrics;
#[cfg(feature = "proxy")]
mod outpost;
#[cfg(feature = "core")]
mod server;
#[cfg(feature = "core")]
@@ -31,8 +29,6 @@ enum Command {
Server(server::Cli),
#[cfg(feature = "core")]
Worker(worker::Cli),
#[cfg(feature = "proxy")]
Proxy(outpost::proxy::Cli),
}
#[derive(Debug, FromArgs, PartialEq)]
@@ -57,8 +53,6 @@ fn main() -> Result<()> {
Command::Server(_) => Mode::set(Mode::Server)?,
#[cfg(feature = "core")]
Command::Worker(_) => Mode::set(Mode::Worker)?,
#[cfg(feature = "proxy")]
Command::Proxy(_) => Mode::set(Mode::Proxy)?,
}
trace!("installing error formatting");
@@ -114,10 +108,6 @@ fn main() -> Result<()> {
let workers = worker::start(args, &mut tasks)?;
metrics.workers.store(Some(workers));
}
#[cfg(feature = "proxy")]
Command::Proxy(args) => {
outpost::start::<outpost::proxy::ProxyOutpost>(args, &mut tasks).await?;
}
}
let errors = tasks.run().await;

View File

@@ -1,312 +0,0 @@
use std::{fmt::Display, sync::Arc};
use ak_common::{Arbiter, Tasks, VERSION, api, arbiter, authentik_build_hash};
use axum::http::{HeaderValue, header::AUTHORIZATION};
use eyre::{Result, eyre};
use futures::{SinkExt as _, StreamExt as _};
use nix::unistd::gethostname;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use time::UtcDateTime;
use tokio::{
signal::unix::SignalKind,
time::{Duration, interval, sleep},
};
use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest as _};
use tracing::{debug, info, instrument, trace, warn};
use url::Url;
use crate::outpost::{Outpost, OutpostController};
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone, Copy, Eq)]
#[repr(u8)]
enum EventKind {
/// Code used to acknowledge a previous message.
Ack = 0,
/// Code used to send a healthcheck keepalive.
Hello = 1,
/// Code received to trigger a config update.
TriggerUpdate = 2,
/// Code received to trigger some provider specific function.
ProviderSpecific = 3,
/// Code received to identify the end of a session.
SessionEnd = 4,
}
impl Display for EventKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Ack => write!(f, "Ack"),
Self::Hello => write!(f, "Hello"),
Self::TriggerUpdate => write!(f, "TriggerUpdate"),
Self::ProviderSpecific => write!(f, "ProviderSpecific"),
Self::SessionEnd => write!(f, "SessionEnd"),
}
}
}
#[derive(Serialize, Deserialize)]
struct Event {
instruction: EventKind,
args: serde_json::Value,
}
#[derive(Debug, Deserialize)]
pub(crate) struct EventSessionEnd {
session_id: String,
}
fn build_ws_url(mut url: Url, outpost_pk: &str, instance_uuid: &str, attempt: u32) -> Result<Url> {
let ws_scheme = match url.scheme() {
"https" => "wss",
"http" => "ws",
other => return Err(eyre!("Unsupported scheme for WebSocket URL: {other}")),
};
url.set_scheme(ws_scheme)
.map_err(|()| eyre!("Failed to set URL scheme to {ws_scheme}"))?;
url.set_path(&format!("{}ws/outpost/{outpost_pk}/", url.path()));
url.query_pairs_mut()
.append_pair("instance_uuid", instance_uuid)
.append_pair("attempt", &attempt.to_string());
Ok(url)
}
fn hello_args(instance_uuid: &str) -> serde_json::Value {
let raw_hostname = gethostname().unwrap_or_default();
let hostname = raw_hostname.to_string_lossy();
serde_json::json!({
"version": VERSION,
"buildHash": authentik_build_hash(None),
"uuid": instance_uuid,
// TODO: rust version and AWS-LC versions
"hostname": hostname,
})
}
#[instrument(skip_all)]
async fn handle_event<O: Outpost>(
controller: Arc<OutpostController>,
outpost: Arc<O>,
event: Event,
) -> Result<()> {
match event.instruction {
EventKind::Ack | EventKind::Hello => {}
EventKind::TriggerUpdate => {
info!("received update trigger, refreshing outpost");
sleep(controller.reload_offset).await;
controller.refresh().await?;
debug!("outpost controller has been refreshed");
outpost.refresh().await?;
debug!("outpost has been refreshed");
#[expect(
clippy::as_conversions,
clippy::cast_precision_loss,
reason = "This is fine, we'll never get big values here."
)]
controller
.m_last_update
.set(UtcDateTime::now().unix_timestamp() as f64);
}
EventKind::SessionEnd => {
let event: EventSessionEnd = serde_json::from_value(event.args)?;
outpost.end_session(event).await?;
}
#[expect(
clippy::unimplemented,
reason = "this is only relevant for the RAC provider"
)]
EventKind::ProviderSpecific => unimplemented!(),
}
Ok(())
}
async fn watch_events_inner<O: Outpost>(
arbiter: Arbiter,
controller: Arc<OutpostController>,
outpost: Arc<O>,
attempt: u32,
) -> Result<()> {
let server_config = api::ServerConfig::new()?;
let ws_url = build_ws_url(
server_config.host,
&controller.outpost.load().pk.to_string(),
&controller.instance_uuid.to_string(),
attempt,
)?;
debug!(url = %ws_url, "connecting to websocket");
let mut request = ws_url.into_client_request()?;
let token = controller
.api_config
.bearer_access_token
.as_deref()
.unwrap_or("");
request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}"))?,
);
let (ws_stream, _response) = tokio_tungstenite::connect_async(request).await?;
let (mut ws_write, mut ws_read) = ws_stream.split();
info!(
outpost = %controller.outpost.load().pk,
"connected to websocket"
);
controller.m_connection.set(1_u8);
let get_refresh_interval = || {
let mut interval = controller.outpost.load().refresh_interval_s;
// Ensure timer interval is not negative or 0.
// If it is, we default to 5 minutes.
if interval <= 0_i32 {
interval = 60_i32 * 5_i32;
}
// Clamp interval to be at least 30 seconds.
if interval < 30_i32 {
interval = 30_i32;
}
// infallible because we bound it to be positive above
Duration::from_secs(interval.try_into().expect("infallible"))
};
let mut refresh_interval = interval(get_refresh_interval());
let mut heartbeat_interval = interval(Duration::from_secs(10));
let mut events_rx = arbiter.events_subscribe();
loop {
tokio::select! {
_ = refresh_interval.tick() => {
info!("refreshing outpost on interval");
if let Err(err) = handle_event(
Arc::clone(&controller),
Arc::clone(&outpost),
Event {
instruction: EventKind::TriggerUpdate,
args: serde_json::Value::Null
}
).await {
warn!(?err, "failed to refresh");
}
refresh_interval = interval(get_refresh_interval());
// Since we re-create the interval, we need to make it tick instantly to avoid
// ending up in a never-ending tick-loop.
refresh_interval.tick().await;
},
_ = heartbeat_interval.tick() => {
let ping = Event {
instruction: EventKind::Hello,
args: hello_args(&controller.instance_uuid.to_string()),
};
ws_write.send(Message::text(serde_json::to_string(&ping)?)).await?;
trace!("sent websocket hello (heartbeat)");
},
Ok(arbiter::Event::Signal(signal)) = events_rx.recv() => {
if signal == SignalKind::user_defined1() {
info!("refreshing outpost on signal");
if let Err(err) = handle_event(
Arc::clone(&controller),
Arc::clone(&outpost),
Event {
instruction: EventKind::TriggerUpdate,
args: serde_json::Value::Null
}
).await {
warn!(?err, "failed to refresh");
}
}
},
msg = ws_read.next() => {
let Some(msg) = msg else {
break;
};
let msg = msg?;
match msg {
Message::Text(text) => {
let Ok(event): Result<Event, _> = serde_json::from_str(&text) else {
warn!(data = text.as_str(), "failed to parse event");
continue;
};
trace!(event = %event.instruction, "received websocket event");
if let Err(err) = handle_event(
Arc::clone(&controller),
Arc::clone(&outpost),
event,
).await {
warn!(?err, "failed to handle event");
}
},
Message::Ping(data) => {
ws_write.send(Message::Pong(data)).await?;
},
Message::Close(_) => {
break;
},
_ => {},
}
},
() = arbiter.shutdown() => break,
}
}
Ok(())
}
async fn watch_events<O: Outpost>(
arbiter: Arbiter,
controller: Arc<OutpostController>,
outpost: Arc<O>,
) -> Result<()> {
const MAX_BACKOFF: Duration = Duration::from_mins(5);
let mut backoff = Duration::from_secs(1);
let mut attempt: u32 = 0;
loop {
tokio::select! {
() = arbiter.shutdown() => break,
res = watch_events_inner(
arbiter.clone(),
Arc::clone(&controller),
Arc::clone(&outpost),
attempt
) => {
controller.m_connection.set(0_u8);
match res {
Ok(()) => debug!("websocket disconnected cleanly"),
Err(err) => warn!(?err, attempt, "websocket error"),
}
info!(attempt, delay = backoff.as_secs(), "reconnecting websocket in {}s...", backoff.as_secs());
tokio::select! {
() = arbiter.shutdown() => break,
() = sleep(backoff) => {}
}
backoff = (backoff * 2).min(MAX_BACKOFF);
attempt += 1;
}
}
}
info!("stopping event watcher");
Ok(())
}
pub(crate) fn start<O: Outpost + 'static>(
tasks: &mut Tasks,
controller: Arc<OutpostController>,
outpost: Arc<O>,
) -> Result<()> {
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::watch_events", module_path!()))
.spawn(watch_events(arbiter, controller, outpost))?;
Ok(())
}

View File

@@ -1,123 +0,0 @@
use std::{sync::Arc, time::Duration};
use ak_client::{
apis::{configuration::Configuration, outposts_api::outposts_instances_list},
models::Outpost as OutpostModel,
};
use ak_common::{Tasks, VERSION, api, authentik_build_hash};
use arc_swap::ArcSwap;
use eyre::{Result, eyre};
use tracing::{debug, info, instrument};
use uuid::Uuid;
pub(crate) mod event;
#[cfg(feature = "proxy")]
pub(crate) mod proxy;
pub(crate) trait Outpost: Send + Sync + Sized {
const OUTPOST_TYPE: &'static str;
type Cli: Send + Sync;
async fn new(controller: Arc<OutpostController>) -> Result<Self>;
fn start(self: Arc<Self>, tasks: &mut Tasks) -> Result<()>;
fn refresh(&self) -> impl Future<Output = Result<()>> + Send;
fn end_session(&self, event: event::EventSessionEnd)
-> impl Future<Output = Result<()>> + Send;
}
#[derive(Debug)]
pub(crate) struct OutpostController {
api_config: Configuration,
outpost: ArcSwap<OutpostModel>,
instance_uuid: Uuid,
reload_offset: Duration,
m_info: metrics::Gauge,
m_last_update: metrics::Gauge,
m_connection: metrics::Gauge,
}
impl OutpostController {
#[instrument(skip_all)]
async fn get_outpost(api_config: &Configuration) -> Result<OutpostModel> {
let outposts = outposts_instances_list(
api_config, None, None, None, None, None, None, None, None, None, None, None, None,
)
.await?;
let Some(outpost) = outposts.results.into_iter().next() else {
return Err(eyre!(
"No outposts found with given token, ensure the given token corresponds to an \
authentik Outpost"
));
};
debug!(name = outpost.name, "fetched outpost configuration");
Ok(outpost)
}
#[instrument(skip_all)]
async fn new<O: Outpost>() -> Result<Self> {
let api_config = api::make_config()?;
let outpost = Self::get_outpost(&api_config).await?;
let instance_uuid = Uuid::new_v4();
let m_labels = [
("outpost_name", outpost.name.clone()),
("outpost_type", O::OUTPOST_TYPE.to_owned()),
("uuid", instance_uuid.to_string()),
("version", VERSION.to_owned()),
("build", authentik_build_hash(None)),
];
metrics::describe_gauge!("authentik_outpost_info", "Outpost info");
let m_info = metrics::gauge!("authentik_outpost_info", &m_labels);
metrics::describe_gauge!("authentik_outpost_last_update", "Time of last update");
let m_last_update = metrics::gauge!("authentik_outpost_last_update", &m_labels);
metrics::describe_gauge!("authentik_outpost_connection", "Connection status");
let m_connection = metrics::gauge!("authentik_outpost_connection", &m_labels);
let reload_offset = Duration::from_secs(rand::random_range(0..10));
let controller = Self {
api_config,
outpost: ArcSwap::from_pointee(outpost),
instance_uuid,
reload_offset,
m_info,
m_last_update,
m_connection,
};
info!(embedded = controller.is_embedded(), "outpost mode");
debug!(?reload_offset, "HA Reload offset");
Ok(controller)
}
fn is_embedded(&self) -> bool {
self.outpost
.load()
.managed
.as_ref()
.and_then(|m| m.as_deref())
.is_some_and(|m| m == "goauthentik.io/outposts/embedded")
}
async fn refresh(&self) -> Result<()> {
let outpost = Self::get_outpost(&self.api_config).await?;
self.outpost.swap(Arc::new(outpost));
Ok(())
}
}
#[instrument(skip_all)]
pub(crate) async fn start<O: Outpost + 'static>(_cli: O::Cli, tasks: &mut Tasks) -> Result<()> {
let controller = Arc::new(OutpostController::new::<O>().await?);
let outpost = Arc::new(O::new(Arc::clone(&controller)).await?);
event::start(tasks, Arc::clone(&controller), Arc::clone(&outpost))?;
outpost.start(tasks)?;
controller.m_info.set(1_u8);
Ok(())
}

View File

@@ -1,82 +0,0 @@
use std::sync::Arc;
use ak_client::{
apis::crypto_api::{
crypto_certificatekeypairs_view_certificate_retrieve,
crypto_certificatekeypairs_view_private_key_retrieve,
},
models::ProxyOutpostConfig,
};
use axum::Router;
use eyre::{Result, eyre};
use rustls::{
crypto::CryptoProvider,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
sign::CertifiedKey,
};
use tracing::instrument;
use url::Url;
use crate::outpost::proxy::ProxyOutpost;
const REDIRECT_PARAM: &str = "rd";
const CALLBACK_SIGNATURE: &str = "X-authentik-auth-callback";
const LOGOUT_SIGNATURE: &str = "X-authentik-logout";
#[derive(Debug)]
pub(super) struct Application {
pub(super) host: String,
pub(super) provider: ProxyOutpostConfig,
pub(super) router: Router,
pub(super) cert: Option<Arc<CertifiedKey>>,
}
impl Application {
#[instrument(skip_all)]
pub(super) async fn new(outpost: &ProxyOutpost, provider: ProxyOutpostConfig) -> Result<Self> {
let external_url = Url::parse(&provider.external_host)?;
if !external_url.has_authority() {
return Err(eyre!("no host in external host"));
}
let external_host = external_url.authority();
// TODO: extract this to a certificate store to avoid re-fetching the certificate every time
let cert = if let Some(Some(kp_uuid)) = provider.certificate {
let cert = crypto_certificatekeypairs_view_certificate_retrieve(
&outpost.controller.api_config,
&kp_uuid.to_string(),
None,
)
.await?;
let key = crypto_certificatekeypairs_view_private_key_retrieve(
&outpost.controller.api_config,
&kp_uuid.to_string(),
None,
)
.await?;
let cert_chain = CertificateDer::pem_reader_iter(cert.data.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key_der = PrivateKeyDer::from_pem_reader(key.data.as_bytes())?;
let provider = CryptoProvider::get_default().expect("no rustls provider installed");
Some(Arc::new(CertifiedKey::new(
cert_chain,
provider.key_provider.load_private_key(key_der)?,
)))
} else {
None
};
let _redirect_url = {
let mut redirect_url = external_url.join("outpost.goauthentik.io/callback")?;
redirect_url.set_query(Some(&format!("{CALLBACK_SIGNATURE}=true")));
redirect_url
};
Ok(Self {
host: external_host.to_owned(),
provider,
router: Router::new(),
cert,
})
}
}

View File

@@ -1,76 +0,0 @@
use std::sync::Arc;
use ak_axum::{error::Result, extract::host::Host};
use axum::{
extract::{Request, State},
http::{Method, StatusCode, header::CONTENT_TYPE},
response::{IntoResponse as _, Response},
};
use metrics::histogram;
use serde_json::json;
use tokio::time::Instant;
use tower::util::ServiceExt as _;
use tracing::{Instrument as _, field, info_span, instrument, trace, warn};
use crate::outpost::proxy::ProxyOutpost;
#[instrument(skip_all)]
pub(super) async fn handle_ping(
method: Method,
Host(host): Host,
State(outpost): State<Arc<ProxyOutpost>>,
) -> Response {
let start = Instant::now();
histogram!(
"authentik_outpost_proxy_request_duration_seconds",
"outpost_name" => outpost.controller.outpost.load().name.clone(),
"method" => method.to_string(),
"host" => host,
"type" => "ping",
)
.record(start.elapsed().as_secs_f64());
StatusCode::NO_CONTENT.into_response()
}
#[instrument(skip_all)]
pub(super) async fn default(
method: Method,
Host(host): Host,
State(outpost): State<Arc<ProxyOutpost>>,
request: Request,
) -> Result<Response> {
let span = info_span!("proxy outpost request", user = field::Empty);
let start = Instant::now();
let Some(app) = outpost.lookup_app(&host) else {
trace!(headers = ?request.headers(), "tracing headers for no hostname match");
warn!("no app for hostname");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(CONTENT_TYPE, "application/json")
.body(
json!({
"message": "no app for hostname",
"host": host,
"detail": format!("check the outpost settings and make sure '{host}' is included."),
})
.to_string()
.into(),
)
.expect("infallible"));
};
trace!("passing to application");
let response = app.router.clone().oneshot(request).instrument(span).await?;
histogram!(
"authentik_outpost_proxy_request_duration_seconds",
"outpost_name" => outpost.controller.outpost.load().name.clone(),
"method" => method.to_string(),
"host" => host,
"type" => "app",
)
.record(start.elapsed().as_secs_f64());
Ok(response)
}

View File

@@ -1,228 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use ak_axum::router::wrap_router;
use ak_client::{apis::outposts_api::outposts_proxy_list, models::ProxyMode};
use ak_common::{Tasks, api::fetch_all, config, tls};
use arc_swap::ArcSwap;
use argh::FromArgs;
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use eyre::Result;
use rustls::{
ServerConfig,
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use tracing::{debug, error, info, instrument, warn};
use crate::outpost::{Outpost, OutpostController, proxy::application::Application};
mod application;
mod handlers;
#[derive(Debug, Default, FromArgs, PartialEq, Eq)]
/// Run the authentik proxy outpost.
#[argh(subcommand, name = "proxy")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}
#[derive(Debug)]
pub(crate) struct ProxyOutpost {
controller: Arc<OutpostController>,
apps: ArcSwap<HashMap<String, Arc<Application>>>,
default_cert: Arc<CertifiedKey>,
}
impl Outpost for ProxyOutpost {
type Cli = Cli;
const OUTPOST_TYPE: &'static str = "proxy";
#[instrument(skip_all)]
async fn new(controller: Arc<OutpostController>) -> Result<Self> {
Ok(Self {
controller,
apps: ArcSwap::from_pointee(HashMap::with_capacity(0)),
default_cert: Arc::new(tls::self_signed::generate_certifiedkey()?),
})
}
fn start(self: Arc<Self>, tasks: &mut Tasks) -> Result<()> {
let router = build_router(Arc::clone(&self));
for addr in config::get().listen.http.iter().copied() {
ak_axum::server::start_plain(tasks, "proxy-outpost", router.clone(), addr)?;
}
for addr in config::get().listen.https.iter().copied() {
let resolver = Arc::clone(&self);
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(resolver);
let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
ak_axum::server::start_tls(
tasks,
"proxy-outpost",
router.clone(),
addr,
rustls_config,
)?;
}
Ok(())
}
#[instrument(skip_all)]
async fn refresh(&self) -> Result<()> {
debug!(
outpost_pk = %self.controller.outpost.load().pk,
"requesting providers for outpost"
);
let providers = fetch_all(
|page| {
outposts_proxy_list(
&self.controller.api_config,
None,
None,
Some(page),
Some(100_i32),
None,
)
},
|r| &r.pagination,
|r| r.results,
)
.await
.inspect_err(|err| error!(?err, "failed to fetch providers"))?;
debug!(count = providers.len(), "fetched providers");
if providers.is_empty() && !self.controller.is_embedded() {
warn!(
"no providers assigned to this outpost, check outpost configuration in authentik"
);
}
for (i, provider) in providers.iter().enumerate() {
debug!(
index = i,
name = provider.name,
external_host = provider.external_host,
assigned_to_app = provider.assigned_application_name,
"provider details"
);
}
let mut apps = HashMap::with_capacity(providers.len());
for provider in providers {
let name = provider.name.clone();
let Ok(application) = Application::new(self, provider)
.await
.inspect_err(|err| warn!(?err, "failed to setup application, skipping provider"))
else {
continue;
};
info!(name, host = application.host, "loaded application");
apps.insert(application.host.clone(), Arc::new(application));
}
self.apps.store(Arc::new(apps));
Ok(())
}
async fn end_session(&self, _event: super::event::EventSessionEnd) -> Result<()> {
// todo!()
warn!(?_event, "removing session");
Ok(())
}
}
impl ResolvesServerCert for ProxyOutpost {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if let Some(server_name) = client_hello.server_name()
&& let Some(app) = self.apps.load().get(server_name)
&& let Some(cert) = &app.cert
{
return Some(Arc::clone(cert));
}
Some(Arc::clone(&self.default_cert))
}
fn only_raw_public_keys(&self) -> bool {
false
}
}
impl ProxyOutpost {
#[instrument(skip(self))]
fn lookup_app(&self, host: &str) -> Option<Arc<Application>> {
let apps = self.apps.load();
// If we only have a single app, host name switching doesn't matter.
if apps.len() == 1
&& let Some(app) = apps.values().next()
{
debug!(app = app.provider.name, "found a single app, using it");
return Some(Arc::clone(app));
}
if let Some(app) = apps.get(host) {
debug!(app = app.provider.name, "found app based direct host match");
return Some(Arc::clone(app));
}
// For forward_auth_domain, we don't have a direct app to domain relationship.
// Check through all apps, and check how much of their cookie domain matches the host.
// Return the application that has the longest match.
let mut longest_match = None;
let mut longest_len = 0_usize;
for app in apps.values() {
if app.provider.mode != Some(ProxyMode::ForwardDomain) {
continue;
}
if let Some(cookie_domain) = app.provider.cookie_domain.as_deref() {
// Check if the cookie domain has a leading period for a wildcard.
// This will decrease the weight of a wildcard domain, but a request to example.com
// with the cookie domain set to example.com will still be routed correctly.
let domain = cookie_domain.trim_start_matches('.');
if host.ends_with(domain) && domain.len() > longest_len {
longest_len = domain.len();
longest_match = Some(Arc::clone(app));
}
// For forward_auth_domain, we need to response on the external domain too.
if app.provider.external_host == host {
debug!(app = app.provider.name, "found app based on external_host");
return Some(Arc::clone(app));
}
}
}
if let Some(app) = &longest_match {
debug!(app = app.provider.name, "found app based on cookie domain");
}
longest_match
}
}
fn build_router(outpost: Arc<ProxyOutpost>) -> Router {
wrap_router(
Router::new()
.nest(
"/outpost.goauthentik.io/ping",
Router::new().fallback(handlers::handle_ping),
)
.fallback(handlers::default)
.with_state(outpost),
true,
)
}

11
uv.lock generated
View File

@@ -1143,7 +1143,7 @@ requires-dist = [
{ name = "django", specifier = ">=4.2,<6.0" },
{ name = "django-pglock", specifier = ">=1.7,<2" },
{ name = "django-pgtrigger", specifier = ">=4,<5" },
{ name = "dramatiq", specifier = ">=1.17,<1.18" },
{ name = "dramatiq", specifier = ">=2,<3" },
{ name = "structlog", specifier = ">=25,<26" },
{ name = "tenacity", specifier = ">=9,<10" },
]
@@ -1381,14 +1381,11 @@ wheels = [
[[package]]
name = "dramatiq"
version = "1.17.1"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "prometheus-client" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c6/7a/6792ddc64a77d22bfd97261b751a7a76cf2f9d62edc59aafb679ac48b77d/dramatiq-1.17.1.tar.gz", hash = "sha256:2675d2f57e0d82db3a7d2a60f1f9c536365349db78c7f8d80a63e4c54697647a", size = 99071, upload-time = "2024-10-26T05:09:28.283Z" }
sdist = { url = "https://files.pythonhosted.org/packages/22/69/02b54e3fc4fe75721b322bc578054b4f03cec258ba614fa98a1a5bbe1efe/dramatiq-2.1.0.tar.gz", hash = "sha256:cf81550729de6cf64234b05bd63970645654aaf38967faa7a2b6e401384bb090", size = 105444, upload-time = "2026-03-03T11:22:10.067Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/36/925c7afd5db4f1a3f00676b9c3c58f31ff7ae29a347282d86c8d429280a5/dramatiq-1.17.1-py3-none-any.whl", hash = "sha256:951cdc334478dff8e5150bb02a6f7a947d215ee24b5aedaf738eff20e17913df", size = 120382, upload-time = "2024-10-26T05:09:26.436Z" },
{ url = "https://files.pythonhosted.org/packages/c2/91/422960c8c415fd31ca1519d71d6f7e4bcabb2cdcc5872f784467e9fe7237/dramatiq-2.1.0-py3-none-any.whl", hash = "sha256:3ef940c2815722d3679aed79ef96c805f02fd33d4361529b2de30f01511ca44d", size = 125543, upload-time = "2026-03-03T11:22:08.664Z" },
]
[[package]]

View File

@@ -7,152 +7,186 @@
*/
/**
* @typedef {(this: Lexer, chr: string) => any} DefunctFunction
* A token produced by a {@link LexerAction}. The lexer is agnostic to the
* concrete token shape; consumers pick whatever representation suits them.
*
* @typedef {unknown} Token
*/
/**
* @typedef {(this: Lexer, ...args: RegExpExecArray) => string | string[] | undefined} RuleAction
* A rule action. Invoked with the regex match (full match followed by capture
* groups) bound to the owning {@link Lexer} so it can read or set `state`,
* `index`, and `reject`.
*
* Return values:
* - `null` (or `undefined` from an implicit return) — discard the match and continue scanning.
* - a single token — yield it from {@link Lexer.lex}.
* - an array of tokens — yield the first; queue the rest for subsequent calls.
*
* @callback LexerAction
* @this {Lexer}
* @param {...string} match
* @returns {Token | Token[] | null | void}
*/
/**
* @typedef {Object} Rule
* @property {RegExp} pattern
* @property {boolean} global
* @property {RuleAction} action
* @property {number[]} start
* @typedef {object} LexerRule
* @property {RegExp} pattern Sticky-compiled pattern used to probe the input.
* @property {boolean} global Whether the user-supplied pattern was global.
* @property {LexerAction} action
* @property {number[]} start States in which the rule is active. `[0]` is the default state; an empty array means "any state".
*/
/**
* @typedef {Object} Match
* @typedef {object} LexerMatch
* @property {RegExpExecArray} result
* @property {RuleAction} action
* @property {LexerAction} action
* @property {number} length
* @property {boolean} global Whether the producing rule was declared with the `g` flag.
*/
/**
* Handler invoked when no rule matches at the current position.
*
* @callback DefunctHandler
* @this {Lexer}
* @param {string} chr The unexpected character.
* @returns {Token | Token[] | null | void}
*/
/**
* @type {DefunctHandler}
*/
function defaultDefunct(chr) {
throw new Error(`Unexpected character at index ${this.index - 1}: ${chr}`);
}
/**
* Lexer class for tokenizing input strings.
*/
export class Lexer {
/**
* @type {string[]}
*/
tokens = [];
/**
* @type {Rule[]}
*/
rules = [];
/**
* @type {number}
*/
remove = 0;
/**
* Current lexer state. Rules whose `start` array contains this value (or
* is empty) are eligible to match. Odd-numbered states are also matched
* by rules declared with `start: [0]`, mirroring flex's inclusive states.
*
* @type {number}
*/
state = 0;
/**
* @type {number}
*/
/** @type {number} */
index = 0;
/**
* @type {string}
*/
/** @type {string} */
input = "";
/**
* @param {DefunctFunction} [defunct]
* When set to `true` from inside an action, the current match is rolled
* back and the next-best match is tried instead.
*
* @type {boolean}
*/
reject = false;
/** @type {LexerRule[]} */
#rules = [];
/** @type {Token[]} */
#tokens = [];
/** @type {number} */
#remove = 0;
/** @type {DefunctHandler} */
#defunct;
/**
* @param {DefunctHandler} [defunct] Optional handler for unexpected characters.
*/
constructor(defunct) {
defunct ||= function (chr) {
throw new Error("Unexpected character at index " + (this.index - 1) + ": " + chr);
};
this.defunct = defunct;
this.#defunct = typeof defunct === "function" ? defunct : defaultDefunct;
}
/**
* Add a lexing rule.
* Register a tokenization rule.
*
* @param {RegExp} pattern
* @param {RuleAction} action
* @param {number[]} [start]
* @returns {Lexer}
* @param {LexerAction} action
* @param {number[]} [start] States in which the rule is active. Defaults to `[0]`.
* @returns {this}
*/
addRule = (pattern, action, start) => {
addRule(pattern, action, start) {
const global = pattern.global;
if (!global || !pattern.sticky) {
let flags = "gy";
if (pattern.multiline) flags += "m";
if (pattern.ignoreCase) flags += "i";
if (pattern.unicode) flags += "u";
pattern = new RegExp(pattern.source, flags);
}
if (!Array.isArray(start)) start = [0];
this.rules.push({
pattern: pattern,
global: global,
action: action,
start: start,
this.#rules.push({
pattern,
global,
action,
start: Array.isArray(start) ? start : [0],
});
return this;
};
}
/**
* Set the input string for lexing.
* Reset the lexer and load a new input string.
*
* @param {string} input
* @returns {Lexer}
* @returns {this}
*/
setInput = (input) => {
this.remove = 0;
setInput(input) {
this.#remove = 0;
this.state = 0;
this.index = 0;
this.tokens.length = 0;
this.#tokens.length = 0;
this.input = input;
return this;
};
}
/**
* Lex the next token from the input.
* Produce the next token from the input, or `null` once exhausted.
*
* @returns {string | string[] | undefined}
* @returns {Token | null}
*/
lex = () => {
if (this.tokens.length) return this.tokens.shift();
lex() {
if (this.#tokens.length) return /** @type {Token} */ (this.#tokens.shift());
this.reject = true;
while (this.index <= this.input.length) {
const matches = this.scan().splice(this.remove);
const matches = this.#scan().splice(this.#remove);
const index = this.index;
while (matches.length) {
if (!this.reject) {
break;
}
const match = matches.shift();
if (!this.reject) break;
if (!match) break;
const result = match.result;
const length = match.length;
const match = /** @type {LexerMatch} */ (matches.shift());
const { result, length } = match;
this.index += length;
this.reject = false;
this.remove++;
this.#remove++;
let token = match.action.apply(this, result);
let token = match.action.apply(
this,
/** @type {string[]} */ (/** @type {unknown} */ (result)),
);
if (this.reject) {
this.index = result.index;
} else if (Array.isArray(token)) {
this.tokens = token.slice(1);
token = token[0];
} else {
if (length) this.remove = 0;
} else if (token !== null && token !== undefined) {
if (Array.isArray(token)) {
this.#tokens = token.slice(1);
token = token[0];
}
if (length) this.#remove = 0;
return token;
}
}
@@ -161,79 +195,82 @@ export class Lexer {
if (index < input.length) {
if (this.reject) {
this.remove = 0;
const token = this.defunct(input.charAt(this.index++));
if (typeof token !== "undefined") {
this.#remove = 0;
const token = this.#defunct(input.charAt(this.index++));
if (token !== null && token !== undefined) {
if (Array.isArray(token)) {
this.tokens = token.slice(1);
this.#tokens = token.slice(1);
return token[0];
}
return token;
}
} else {
if (this.index !== index) this.remove = 0;
if (this.index !== index) this.#remove = 0;
this.reject = true;
}
} else if (matches.length) this.reject = true;
else break;
} else if (matches.length) {
this.reject = true;
} else {
break;
}
}
};
return null;
}
/**
* Scan the input for matches.
* Probe every state-eligible rule at the current position, returning the
* matches sorted by length (longest first), with global rules pinned
* after non-global ones to preserve flex's "longest non-global wins"
* tie-breaking.
*
* @returns {Match[]}
* @returns {LexerMatch[]}
*/
scan = () => {
/**
* @type {Match[]}
*/
#scan() {
/** @type {LexerMatch[]} */
const matches = [];
let index = 0;
const state = this.state;
const lastIndex = this.index;
const input = this.input;
for (let i = 0, length = this.rules.length; i < length; i++) {
const rule = this.rules[i];
for (const rule of this.#rules) {
const start = rule.start;
const states = start.length;
const eligible =
!states || start.indexOf(state) >= 0 || (state % 2 && states === 1 && !start[0]);
if (!states || start.indexOf(state) >= 0 || (state % 2 && states === 1 && !start[0])) {
const pattern = rule.pattern;
pattern.lastIndex = lastIndex;
const result = pattern.exec(input);
if (!eligible) continue;
if (!result || result.index !== lastIndex) {
continue;
}
const pattern = rule.pattern;
pattern.lastIndex = lastIndex;
const result = pattern.exec(input);
let j = matches.push({
result: result,
action: rule.action,
length: result[0].length,
});
if (!result || result.index !== lastIndex) continue;
if (rule.global) {
index = j;
}
let j = matches.push({
result,
action: rule.action,
length: result[0].length,
global: rule.global,
});
while (--j > index) {
const k = j - 1;
while (--j > 0) {
const k = j - 1;
const cur = matches[j];
const prev = matches[k];
const longer = cur.length > prev.length;
const tieFavorsCur = cur.length === prev.length && prev.global && !cur.global;
if (matches[j].length > matches[k].length) {
const temple = matches[j];
matches[j] = matches[k];
matches[k] = temple;
}
}
if (!longer && !tieFavorsCur) break;
matches[j] = prev;
matches[k] = cur;
}
}
return matches;
};
}
}
export default Lexer;

View File

@@ -23,6 +23,7 @@ import { certificateProvider, certificateSelector } from "#admin/brands/Certific
import {
Application,
AuthenticationEnum,
Brand,
CoreApi,
CoreApplicationsListRequest,
@@ -31,7 +32,6 @@ import {
FlowsApi,
UsageEnum,
} from "@goauthentik/api";
import { AuthenticationEnum } from "@goauthentik/api/dist/models/AuthenticationEnum.js";
import YAML from "yaml";

View File

@@ -17,6 +17,7 @@ import { DesignationToLabel, LayoutToLabel } from "#admin/flows/utils";
import { policyEngineModes } from "#admin/policies/PolicyEngineModes";
import {
AuthenticationEnum,
DeniedActionEnum,
Flow,
FlowDesignationEnum,
@@ -24,7 +25,6 @@ import {
FlowsApi,
UsageEnum,
} from "@goauthentik/api";
import { AuthenticationEnum } from "@goauthentik/api/dist/models/AuthenticationEnum.js";
import { msg } from "@lit/localize";
import { html, TemplateResult } from "lit";

View File

@@ -1,6 +1,6 @@
import { ModelForm } from "#elements/forms/ModelForm";
import type { Stage } from "@goauthentik/api/dist/models/Stage";
import type { Stage } from "@goauthentik/api";
import { msg } from "@lit/localize";

View File

@@ -19,10 +19,10 @@ import {
CoreGroupsListRequest,
Group,
StagesApi,
UserCreationModeEnum,
UserTypeEnum,
UserWriteStage,
} from "@goauthentik/api";
import { UserCreationModeEnum } from "@goauthentik/api/dist/models/UserCreationModeEnum.js";
import { msg } from "@lit/localize";
import { html, TemplateResult } from "lit";

317
web/test/unit/lexer.test.ts Normal file
View File

@@ -0,0 +1,317 @@
/* eslint-disable func-names */
import { Lexer } from "lex";
import { describe, expect, it, vi } from "vitest";
/**
 * Exhaust the lexer, collecting every token it produces.
 *
 * `lex()` yields `null` once the input is exhausted, which terminates the loop.
 */
const drain = (lexer: Lexer): unknown[] => {
    const tokens: unknown[] = [];

    for (let token = lexer.lex(); token !== null; token = lexer.lex()) {
        tokens.push(token);
    }

    return tokens;
};
// Unit tests for the vendored `Lexer` class: rule registration, input
// handling, tokenization, tie-breaking, multi-token returns, rejection,
// defunct (no-match) handling, and start states.
describe("Lexer", () => {
    describe("addRule", () => {
        it("returns the lexer for chaining", () => {
            const lexer = new Lexer();
            expect(lexer.addRule(/a/, () => "a")).toBe(lexer);
        });

        // `addRule` recompiles non-sticky patterns with `gy` flags; the
        // original `m`/`i`/`u` flags must survive that recompilation.
        it("preserves multiline, ignoreCase, and unicode flags when re-compiling", () => {
            const lexer = new Lexer(() => null);
            const seen: string[] = [];
            lexer.addRule(/^a/im, (m) => {
                seen.push(m);
            });
            lexer.setInput("A\nA");
            drain(lexer);
            expect(seen).toEqual(["A", "A"]);
        });

        it("matches unicode patterns", () => {
            const lexer = new Lexer();
            lexer.addRule(/\p{Letter}+/u, (m) => m);
            lexer.setInput("café");
            expect(lexer.lex()).toBe("café");
        });
    });

    describe("setInput", () => {
        it("resets state, index, and pending tokens", () => {
            const lexer = new Lexer();
            lexer.addRule(/./, (c) => c);
            lexer.setInput("ab");
            expect(lexer.lex()).toBe("a");
            lexer.state = 7;
            lexer.setInput("xy");
            expect(lexer.state).toBe(0);
            expect(lexer.index).toBe(0);
            expect(lexer.lex()).toBe("x");
            expect(lexer.lex()).toBe("y");
        });

        it("returns the lexer for chaining", () => {
            const lexer = new Lexer();
            expect(lexer.setInput("")).toBe(lexer);
        });
    });

    describe("tokenization", () => {
        it("tokenizes a simple expression", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/\s+/, () => null)
                .addRule(/[a-zA-Z]+/, (m) => ({ type: "ident", value: m }))
                .addRule(/\d+/, (m) => ({ type: "num", value: Number(m) }))
                .addRule(/[+\-*/]/, (m) => ({ type: "op", value: m }));
            lexer.setInput("foo + 12 * bar");
            expect(drain(lexer)).toEqual([
                { type: "ident", value: "foo" },
                { type: "op", value: "+" },
                { type: "num", value: 12 },
                { type: "op", value: "*" },
                { type: "ident", value: "bar" },
            ]);
        });

        // A `null` return from an action discards the match (e.g. whitespace).
        it("skips matches whose action returns null", () => {
            const lexer = new Lexer();
            lexer.addRule(/\s+/, () => null).addRule(/\S+/, (m) => m);
            lexer.setInput(" foo bar ");
            expect(drain(lexer)).toEqual(["foo", "bar"]);
        });

        it("returns null once the input is exhausted", () => {
            const lexer = new Lexer();
            lexer.addRule(/./, (c) => c);
            lexer.setInput("a");
            expect(lexer.lex()).toBe("a");
            expect(lexer.lex()).toBeNull();
            expect(lexer.lex()).toBeNull();
        });

        // Actions receive the full match followed by each capture group.
        it("passes capture groups to the action", () => {
            const lexer = new Lexer();
            const calls: string[][] = [];
            lexer.addRule(/(\w+)=(\w+)/, (...args) => {
                calls.push(args);
                return args[0];
            });
            lexer.setInput("foo=bar");
            lexer.lex();
            expect(calls).toEqual([["foo=bar", "foo", "bar"]]);
        });

        it("binds `this` to the lexer inside the action", () => {
            const lexer = new Lexer();
            let captured: Lexer | undefined;
            lexer.addRule(/a/, function () {
                // eslint-disable-next-line consistent-this, @typescript-eslint/no-this-alias
                captured = this;
                return "a";
            });
            lexer.setInput("a");
            lexer.lex();
            expect(captured).toBe(lexer);
        });
    });

    describe("longest-match tie-breaking", () => {
        it("prefers the longer non-global match", () => {
            const lexer = new Lexer();
            lexer.addRule(/if/, () => "KW_IF").addRule(/iffy/, () => "IDENT_IFFY");
            lexer.setInput("iffy");
            expect(lexer.lex()).toBe("IDENT_IFFY");
        });

        // Mirrors flex semantics: on equal length, non-global rules win.
        it("treats global rules as fallbacks behind non-global rules of the same length", () => {
            const lexer = new Lexer();
            lexer.addRule(/[a-z]+/g, (m) => `g:${m}`).addRule(/foo/, (m) => `s:${m}`);
            lexer.setInput("foo");
            expect(lexer.lex()).toBe("s:foo");
        });
    });

    describe("multi-token return", () => {
        // An array return yields its first element; the rest are queued for
        // subsequent `lex()` calls.
        it("yields the first token immediately and queues the rest", () => {
            const lexer = new Lexer();
            lexer.addRule(/a/, () => ["A1", "A2", "A3"]);
            lexer.setInput("a");
            expect(lexer.lex()).toBe("A1");
            expect(lexer.lex()).toBe("A2");
            expect(lexer.lex()).toBe("A3");
            expect(lexer.lex()).toBeNull();
        });

        it("drains the queue before scanning further input", () => {
            const lexer = new Lexer();
            lexer.addRule(/a/, () => ["A1", "A2"]).addRule(/b/, () => "B");
            lexer.setInput("ab");
            expect(drain(lexer)).toEqual(["A1", "A2", "B"]);
        });
    });

    describe("reject", () => {
        // Setting `this.reject = true` inside an action retries the
        // next-best match at the same position.
        it("falls through to the next-best match when an action sets reject", () => {
            const lexer = new Lexer();
            const order: string[] = [];
            lexer
                .addRule(/foo/, function () {
                    order.push("first");
                    this.reject = true;
                })
                .addRule(/foo/, () => {
                    order.push("second");
                    return "FOO";
                });
            lexer.setInput("foo");
            expect(lexer.lex()).toBe("FOO");
            expect(order).toEqual(["first", "second"]);
        });

        it("rolls back the lexer index when an action rejects", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/abc/, function () {
                    this.reject = true;
                })
                .addRule(/a/, (m) => m);
            lexer.setInput("abc");
            expect(lexer.lex()).toBe("a");
            expect(lexer.index).toBe(1);
        });
    });

    describe("defunct handling", () => {
        it("throws by default on unexpected characters", () => {
            const lexer = new Lexer();
            lexer.addRule(/a/, (m) => m);
            lexer.setInput("a@");
            expect(lexer.lex()).toBe("a");
            expect(() => lexer.lex()).toThrow(/Unexpected character at index 1: @/);
        });

        it("invokes a custom defunct handler with the offending character", () => {
            const defunct = vi.fn((chr: string) => `?${chr}`);
            const lexer = new Lexer(defunct);
            lexer.addRule(/a/, (m) => m);
            lexer.setInput("a@b");
            expect(drain(lexer)).toEqual(["a", "?@", "?b"]);
            expect(defunct).toHaveBeenCalledTimes(2);
            expect(defunct.mock.calls[0]?.[0]).toBe("@");
        });

        it("ignores defunct return values that are null", () => {
            const lexer = new Lexer((_chr) => null);
            lexer.addRule(/a/, (m) => m);
            lexer.setInput("@@a");
            expect(lexer.lex()).toBe("a");
            expect(lexer.lex()).toBeNull();
        });

        it("supports array returns from the defunct handler", () => {
            const lexer = new Lexer((chr) => [`bad:${chr}`, "extra"]);
            lexer.addRule(/a/, (m) => m);
            lexer.setInput("@");
            expect(lexer.lex()).toBe("bad:@");
            expect(lexer.lex()).toBe("extra");
        });

        // Non-function constructor arguments fall back to the default
        // (throwing) handler at runtime.
        it("falls back to the default handler when given a non-function", () => {
            // @ts-expect-error — exercising the runtime guard
            const lexer = new Lexer("not a function");
            lexer.setInput("@");
            expect(() => lexer.lex()).toThrow(/Unexpected character/);
        });
    });

    describe("states", () => {
        it("only fires rules whose start array includes the current state", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/"/, function () {
                    this.state = 2;
                })
                .addRule(
                    /"/,
                    function () {
                        this.state = 0;
                    },
                    [2],
                )
                .addRule(/[^"]+/, (m) => `STR:${m}`, [2])
                .addRule(/[a-z]+/, (m) => `ID:${m}`);
            lexer.setInput('foo"hello"bar');
            expect(drain(lexer)).toEqual(["ID:foo", "STR:hello", "ID:bar"]);
        });

        it("treats an empty start array as 'active in any state'", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/!/, function () {
                    this.state = 5;
                    return "BANG";
                })
                .addRule(/./, (m) => m, []);
            lexer.setInput("a!b");
            expect(drain(lexer)).toEqual(["a", "BANG", "b"]);
        });

        // Rules with the default `[0]` start are also active in odd states,
        // mirroring flex's inclusive start conditions.
        it("matches inclusive `[0]` rules from odd-numbered states", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/#/, function () {
                    this.state = 1;
                })
                .addRule(/[a-z]+/, (m) => m);
            lexer.setInput("ab#cd");
            expect(drain(lexer)).toEqual(["ab", "cd"]);
        });

        it("does not match `[0]` rules from even non-zero states", () => {
            const lexer = new Lexer();
            lexer
                .addRule(/#/, function () {
                    this.state = 2;
                })
                .addRule(/[a-z]+/, (m) => m);
            lexer.setInput("ab#cd");
            expect(lexer.lex()).toBe("ab");
            expect(() => lexer.lex()).toThrow(/Unexpected character/);
        });
    });
});

View File

@@ -0,0 +1,35 @@
// @file TSConfig used by the web package during development.
{
    "extends": "@goauthentik/tsconfig",
    "compilerOptions": {
        "types": ["node"],
        // Type-check plain-JS sources (e.g. vendored packages) alongside TS.
        "checkJs": true,
        "allowJs": true,
        "composite": true,
        "resolveJsonModule": true,
        "allowSyntheticDefaultImports": true,
        // Only declaration files are emitted; JS output is left to the bundler.
        "emitDeclarationOnly": true,
        "target": "esnext",
        "module": "preserve",
        "moduleResolution": "bundler",
        "lib": ["DOM", "DOM.Iterable", "ESNext"],
        "noUncheckedIndexedAccess": true
    },
    "include": ["./**/*", "../**/*"],
    "exclude": [
        // Build artifacts.
        "**/out/**/*",
        "**/dist/**/*",
        "storybook-static",
        // TODO: @lit/localize-tools v0.8.0 has a nullish coalescing typing error.
        // Remove when we upgrade past that.
        "scripts/pseudolocalize.mjs",
        "scripts/build-locales.mjs"
    ],
    "references": [
        {
            "path": "../.."
        }
    ]
}

View File

@@ -41,9 +41,12 @@ export default defineConfig({
projects: [
{
test: {
include: ["./unit/**/*.{test,spec}.ts", "**/*.unit.{test,spec}.ts"],
name: "unit",
include: ["./test/unit/**/*.{test,spec}.ts", "**/*.unit.{test,spec}.ts"],
name: "Unit Tests",
environment: "node",
typecheck: {
tsconfig: "./tsconfig.unit.json",
},
},
},
{
@@ -51,7 +54,7 @@ export default defineConfig({
setupFiles: ["./test/lit/setup.js"],
include: ["./browser/**/*.{test,spec}.ts", "**/*.browser.{test,spec}.ts"],
name: "browser",
name: "Browser Tests",
browser: {
enabled: true,
provider: playwright(),