mirror of
https://github.com/goauthentik/authentik
synced 2026-05-07 07:32:23 +02:00
Compare commits
4 Commits
rust-proxy
...
blueprint_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8f952558f | ||
|
|
ebd18b466d | ||
|
|
b32df17513 | ||
|
|
1db6c3af8b |
2
.github/actions/setup/action.yml
vendored
2
.github/actions/setup/action.yml
vendored
@@ -49,7 +49,7 @@ runs:
|
||||
if: ${{ contains(inputs.dependencies, 'python') }}
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: uv sync --all-extras --dev --frozen
|
||||
run: uv sync --all-extras --dev --locked
|
||||
- name: Setup rust (stable)
|
||||
if: ${{ contains(inputs.dependencies, 'rust') && !contains(inputs.dependencies, 'rust-nightly') }}
|
||||
uses: actions-rust-lang/setup-rust-toolchain@2b1f5e9b395427c92ee4e3331786ca3c37afe2d7 # v1
|
||||
|
||||
184
Cargo.lock
generated
184
Cargo.lock
generated
@@ -143,45 +143,6 @@ version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60"
|
||||
dependencies = [
|
||||
"asn1-rs-derive",
|
||||
"asn1-rs-impl",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-derive"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-impl"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -215,13 +176,10 @@ dependencies = [
|
||||
"arc-swap",
|
||||
"argh",
|
||||
"authentik-axum",
|
||||
"authentik-client",
|
||||
"authentik-common",
|
||||
"axum",
|
||||
"axum-server",
|
||||
"color-eyre",
|
||||
"eyre",
|
||||
"futures",
|
||||
"hyper-unix-socket",
|
||||
"hyper-util",
|
||||
"metrics",
|
||||
@@ -229,19 +187,9 @@ dependencies = [
|
||||
"nix 0.31.2",
|
||||
"pyo3",
|
||||
"pyo3-build-config",
|
||||
"rand 0.10.1",
|
||||
"rustls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_repr",
|
||||
"sqlx",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-retry2",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
"which",
|
||||
]
|
||||
@@ -299,7 +247,6 @@ dependencies = [
|
||||
"nix 0.31.2",
|
||||
"notify",
|
||||
"pin-project-lite",
|
||||
"rcgen",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
"rustls",
|
||||
@@ -595,17 +542,6 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chacha20"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures 0.3.0",
|
||||
"rand_core 0.10.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.44"
|
||||
@@ -843,15 +779,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc"
|
||||
version = "3.4.0"
|
||||
@@ -946,20 +873,6 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der-parser"
|
||||
version = "10.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.5.8"
|
||||
@@ -1378,7 +1291,6 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"r-efi 6.0.0",
|
||||
"rand_core 0.10.1",
|
||||
"wasip2",
|
||||
"wasip3",
|
||||
]
|
||||
@@ -2270,16 +2182,6 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint-dig"
|
||||
version = "0.8.6"
|
||||
@@ -2509,15 +2411,6 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oid-registry"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.4"
|
||||
@@ -2919,17 +2812,6 @@ dependencies = [
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207"
|
||||
dependencies = [
|
||||
"chacha20",
|
||||
"getrandom 0.4.2",
|
||||
"rand_core 0.10.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
@@ -2968,12 +2850,6 @@ dependencies = [
|
||||
"getrandom 0.3.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69"
|
||||
|
||||
[[package]]
|
||||
name = "rand_xoshiro"
|
||||
version = "0.7.0"
|
||||
@@ -3001,19 +2877,6 @@ dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rcgen"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b99e0098aa4082912d4c649628623db6aba77335e4f4569ff5083a6448b32e"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"rustls-pki-types",
|
||||
"time",
|
||||
"x509-parser",
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.18"
|
||||
@@ -3177,15 +3040,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusticata-macros"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.4"
|
||||
@@ -3565,7 +3419,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures 0.2.17",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
@@ -3576,7 +3430,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures 0.2.17",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
@@ -4146,12 +4000,8 @@ checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tungstenite",
|
||||
"webpki-roots 0.26.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4365,11 +4215,8 @@ dependencies = [
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.4",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"sha1",
|
||||
"thiserror 2.0.18",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5240,24 +5087,6 @@ version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
|
||||
|
||||
[[package]]
|
||||
name = "x509-parser"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"aws-lc-rs",
|
||||
"data-encoding",
|
||||
"der-parser",
|
||||
"lazy_static",
|
||||
"nom",
|
||||
"oid-registry",
|
||||
"rusticata-macros",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yaml-rust2"
|
||||
version = "0.10.4"
|
||||
@@ -5269,15 +5098,6 @@ dependencies = [
|
||||
"hashlink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yasna"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
|
||||
dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
26
Cargo.toml
26
Cargo.toml
@@ -50,11 +50,6 @@ notify = "= 8.2.0"
|
||||
pin-project-lite = "= 0.2.17"
|
||||
pyo3 = "= 0.28.3"
|
||||
pyo3-build-config = "= 0.28.3"
|
||||
rand = "= 0.10.1"
|
||||
rcgen = { version = "= 0.14.7", default-features = false, features = [
|
||||
"aws_lc_rs",
|
||||
"fips",
|
||||
] }
|
||||
regex = "= 1.12.3"
|
||||
reqwest = { version = "= 0.13.3", features = [
|
||||
"form",
|
||||
@@ -105,10 +100,6 @@ time = { version = "= 0.3.47", features = ["macros"] }
|
||||
tokio = { version = "= 1.52.1", features = ["full", "tracing"] }
|
||||
tokio-retry2 = "= 0.9.1"
|
||||
tokio-rustls = "= 0.26.4"
|
||||
tokio-tungstenite = { version = "= 0.29.0", features = [
|
||||
"rustls-tls-webpki-roots",
|
||||
"url",
|
||||
] }
|
||||
tokio-util = { version = "= 0.7.18", features = ["full"] }
|
||||
tower = "= 0.5.3"
|
||||
tower-http = { version = "= 0.6.8", features = ["timeout"] }
|
||||
@@ -269,41 +260,28 @@ publish.workspace = true
|
||||
[features]
|
||||
default = ["core", "proxy"]
|
||||
core = ["ak-common/core", "dep:pyo3", "dep:sqlx"]
|
||||
proxy = ["ak-common/proxy", "dep:ak-client"]
|
||||
proxy = ["ak-common/proxy"]
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config.workspace = true
|
||||
|
||||
[dependencies]
|
||||
ak-axum.workspace = true
|
||||
ak-client = { workspace = true, optional = true }
|
||||
ak-common.workspace = true
|
||||
arc-swap.workspace = true
|
||||
argh.workspace = true
|
||||
axum-server.workspace = true
|
||||
axum.workspace = true
|
||||
color-eyre.workspace = true
|
||||
eyre.workspace = true
|
||||
futures.workspace = true
|
||||
hyper-unix-socket.workspace = true
|
||||
hyper-util.workspace = true
|
||||
metrics-exporter-prometheus.workspace = true
|
||||
metrics.workspace = true
|
||||
metrics-exporter-prometheus.workspace = true
|
||||
nix.workspace = true
|
||||
pyo3 = { workspace = true, optional = true }
|
||||
rand.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_repr.workspace = true
|
||||
sqlx = { workspace = true, optional = true }
|
||||
time.workspace = true
|
||||
tokio-retry2.workspace = true
|
||||
tokio-tungstenite.workspace = true
|
||||
tokio.workspace = true
|
||||
tower.workspace = true
|
||||
tracing.workspace = true
|
||||
url.workspace = true
|
||||
uuid.workspace = true
|
||||
which.workspace = true
|
||||
|
||||
|
||||
@@ -217,10 +217,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
@extend_schema(
|
||||
request={"multipart/form-data": BlueprintUploadSerializer},
|
||||
responses={
|
||||
204: BlueprintImportResultSerializer,
|
||||
400: BlueprintImportResultSerializer,
|
||||
},
|
||||
responses={200: BlueprintImportResultSerializer},
|
||||
)
|
||||
@action(url_path="import", detail=False, methods=["POST"], parser_classes=(MultiPartParser,))
|
||||
@validate(
|
||||
@@ -247,21 +244,13 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
import_response = self.BlueprintImportResultSerializer(
|
||||
data={
|
||||
"logs": [],
|
||||
"success": False,
|
||||
"logs": [LogEventSerializer(log).data for log in logs],
|
||||
"success": valid,
|
||||
}
|
||||
)
|
||||
import_response.is_valid(raise_exception=True)
|
||||
|
||||
import_response.initial_data["logs"] = [LogEventSerializer(log).data for log in logs]
|
||||
import_response.initial_data["success"] = valid
|
||||
import_response.is_valid()
|
||||
if not valid:
|
||||
return Response(data=import_response.initial_data, status=200)
|
||||
|
||||
successful = importer.apply()
|
||||
import_response.initial_data["success"] = successful
|
||||
import_response.is_valid()
|
||||
if not successful:
|
||||
return Response(data=import_response.initial_data, status=200)
|
||||
if valid:
|
||||
import_response.initial_data["success"] = importer.apply()
|
||||
import_response.is_valid()
|
||||
return Response(data=import_response.initial_data, status=200)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from json import dumps, loads
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
from yaml import dump
|
||||
@@ -141,6 +142,20 @@ class TestBlueprintsV1API(APITestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_api_import_invalid_blueprint_returns_result_payload(self):
|
||||
"""Invalid blueprint content returns a result payload instead of a 400 response."""
|
||||
file = SimpleUploadedFile("invalid-blueprint.yaml", b'{"version": 3}')
|
||||
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:blueprintinstance-import-"),
|
||||
data={"file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertFalse(res.json()["success"])
|
||||
self.assertGreater(len(res.json()["logs"]), 0)
|
||||
|
||||
def test_api_import_unknown_path(self):
|
||||
"""Path not in available blueprints is rejected (covers api.py:56)."""
|
||||
res = self.client.post(
|
||||
|
||||
@@ -7,7 +7,7 @@ from dramatiq.broker import Broker, MessageProxy, get_broker
|
||||
from dramatiq.middleware.middleware import Middleware
|
||||
from dramatiq.middleware.retries import Retries
|
||||
from dramatiq.results.middleware import Results
|
||||
from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
|
||||
from dramatiq.worker import ConsumerThread, Worker, WorkerThread
|
||||
|
||||
from authentik.tasks.broker import PostgresBroker
|
||||
|
||||
@@ -20,7 +20,7 @@ class TestWorker(Worker):
|
||||
self.worker_id = 1000
|
||||
self.work_queue = PriorityQueue()
|
||||
self.consumers = {
|
||||
TESTING_QUEUE: _ConsumerThread(
|
||||
TESTING_QUEUE: ConsumerThread(
|
||||
broker=self.broker,
|
||||
queue_name=TESTING_QUEUE,
|
||||
prefetch=2,
|
||||
@@ -33,7 +33,7 @@ class TestWorker(Worker):
|
||||
prefetch=2,
|
||||
timeout=1,
|
||||
)
|
||||
self._worker = _WorkerThread(
|
||||
self._worker = WorkerThread(
|
||||
broker=self.broker,
|
||||
consumers=self.consumers,
|
||||
work_queue=self.work_queue,
|
||||
@@ -78,17 +78,18 @@ def use_test_broker():
|
||||
actor.broker = broker
|
||||
actor.broker.declare_actor(actor)
|
||||
|
||||
for middleware_class, middleware_kwargs in Conf().middlewares:
|
||||
middleware: Middleware = import_string(middleware_class)(
|
||||
for middleware_class_path, middleware_kwargs in Conf().middlewares:
|
||||
middleware_class = import_string(middleware_class_path)
|
||||
if issubclass(middleware_class, Results):
|
||||
middleware_kwargs["backend"] = import_string(Conf().result_backend)(
|
||||
*Conf().result_backend_args,
|
||||
**Conf().result_backend_kwargs,
|
||||
)
|
||||
middleware: Middleware = middleware_class(
|
||||
**middleware_kwargs,
|
||||
)
|
||||
if isinstance(middleware, Retries):
|
||||
middleware.max_retries = 0
|
||||
if isinstance(middleware, Results):
|
||||
middleware.backend = import_string(Conf().result_backend)(
|
||||
*Conf().result_backend_args,
|
||||
**Conf().result_backend_kwargs,
|
||||
)
|
||||
broker.add_middleware(middleware)
|
||||
|
||||
broker.start()
|
||||
|
||||
@@ -79,7 +79,7 @@ function prepare_debug {
|
||||
apt-get update
|
||||
apt-get install -y --no-install-recommends krb5-kdc krb5-user krb5-admin-server libkrb5-dev gcc
|
||||
source "${VENV_PATH}/bin/activate"
|
||||
uv sync --active --frozen
|
||||
uv sync --active --locked
|
||||
touch /unittest.xml
|
||||
chown authentik:authentik /unittest.xml
|
||||
}
|
||||
|
||||
@@ -101,6 +101,8 @@ RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
rustc --version && \
|
||||
cargo --version
|
||||
|
||||
RUN cat /root/.rustup/settings.toml
|
||||
|
||||
# Stage: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.11.5@sha256:555ac94f9a22e656fc5f2ce5dfee13b04e94d099e46bb8dd3a73ec7263f2e484 AS uv
|
||||
# Stage: Base python image
|
||||
@@ -198,7 +200,7 @@ RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
|
||||
--mount=type=bind,target=packages/django-postgres-cache,src=packages/django-postgres-cache \
|
||||
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
--mount=type=cache,id=uv-python-deps-$TARGETARCH$TARGETVARIANT,target=/root/.cache/uv \
|
||||
uv sync --frozen --no-install-project --no-dev
|
||||
uv sync --locked --no-install-project --no-dev
|
||||
|
||||
# Stage: Run
|
||||
FROM python-base AS final-image
|
||||
|
||||
@@ -21,45 +21,33 @@ COPY web .
|
||||
RUN npm run build-proxy
|
||||
|
||||
# Stage 2: Build
|
||||
FROM ghcr.io/goauthentik/fips-debian:trixie-slim-fips@sha256:7726387c78b5787d2146868c2ccc8948a3591d0a5a6436f7780c8c28acc76341 AS builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.26.2-trixie@sha256:4a7137ea573f79c86ae451ff05817ed762ef5597fcf732259e97abeb3108d873 AS builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
SHELL ["/bin/sh", "-o", "pipefail", "-c"]
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
||||
--mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
apt-get update && \
|
||||
# Required for installing pip packages
|
||||
apt-get install -y --no-install-recommends \
|
||||
# Build essentials
|
||||
build-essential \
|
||||
# aws-lc deps
|
||||
cmake clang golang && \
|
||||
curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal --default-toolchain none && \
|
||||
rustup install && \
|
||||
rustup default "$(sed -n 's/channel = "\(.*\)"/\1/p' rust-toolchain.toml)" && \
|
||||
rustc --version && \
|
||||
cargo --version
|
||||
# See https://github.com/aws/aws-lc-rs/issues/569
|
||||
ENV AWS_LC_FIPS_SYS_CC=clang
|
||||
ARG GOOS=$TARGETOS
|
||||
ARG GOARCH=$TARGETARCH
|
||||
|
||||
RUN --mount=type=bind,target=rust-toolchain.toml,src=rust-toolchain.toml \
|
||||
--mount=type=bind,target=Cargo.toml,src=Cargo.toml \
|
||||
--mount=type=bind,target=Cargo.lock,src=Cargo.lock \
|
||||
--mount=type=bind,target=.cargo/,src=.cargo/ \
|
||||
--mount=type=bind,target=src/,src=src/ \
|
||||
--mount=type=bind,target=packages/,src=packages/ \
|
||||
--mount=type=bind,target=authentik/lib/default.yml,src=authentik/lib/default.yml \
|
||||
# Required otherwise workspace discovery fails
|
||||
--mount=type=bind,target=website/scripts/docsmg/,src=website/scripts/docsmg/ \
|
||||
--mount=type=cache,id=cargo-git-db-$TARGETARCH$TARGETVARIANT,target=/root/.cargo/git/db/ \
|
||||
--mount=type=cache,id=cargo-registry-$TARGETARCH$TARGETVARIANT,target=/root/.cargo/registry/ \
|
||||
--mount=type=cache,id=rust-target-$TARGETARCH$TARGETVARIANT,target=/build/target/ \
|
||||
cargo build --package authentik --no-default-features --features proxy --locked --release && \
|
||||
cp ./target/release/authentik /bin/authentik
|
||||
WORKDIR /go/src/goauthentik.io
|
||||
|
||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
||||
dpkg --add-architecture arm64 && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends crossbuild-essential-arm64 gcc-aarch64-linux-gnu
|
||||
|
||||
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
|
||||
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
go mod download
|
||||
|
||||
COPY . .
|
||||
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
|
||||
if [ "$TARGETARCH" = "arm64" ]; then export CC=aarch64-linux-gnu-gcc && export CC_FOR_TARGET=gcc-aarch64-linux-gnu; fi && \
|
||||
CGO_ENABLED=1 GOFIPS140=latest GOARM="${TARGETVARIANT#v}" \
|
||||
go build -o /go/proxy ./cmd/proxy
|
||||
|
||||
# Stage 3: Run
|
||||
FROM ghcr.io/goauthentik/fips-debian:trixie-slim-fips@sha256:7726387c78b5787d2146868c2ccc8948a3591d0a5a6436f7780c8c28acc76341
|
||||
@@ -84,13 +72,13 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /tmp/* /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /bin/authentik /
|
||||
COPY --from=builder /go/proxy /
|
||||
COPY --from=web-builder /static/robots.txt /web/robots.txt
|
||||
COPY --from=web-builder /static/security.txt /web/security.txt
|
||||
COPY --from=web-builder /static/dist/ /web/dist/
|
||||
COPY --from=web-builder /static/authentik/ /web/authentik/
|
||||
|
||||
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/authentik", "healthcheck" ]
|
||||
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/proxy", "healthcheck" ]
|
||||
|
||||
EXPOSE 9000 9300 9443
|
||||
|
||||
@@ -99,4 +87,4 @@ USER 1000
|
||||
ENV TMPDIR=/dev/shm/ \
|
||||
GOFIPS=1
|
||||
|
||||
ENTRYPOINT ["/authentik", "proxy"]
|
||||
ENTRYPOINT ["/proxy"]
|
||||
|
||||
@@ -28,12 +28,7 @@ class HttpHandler(BaseHTTPRequestHandler):
|
||||
_ = db_conn.cursor()
|
||||
|
||||
def do_GET(self):
|
||||
from django.db import (
|
||||
DatabaseError,
|
||||
InterfaceError,
|
||||
OperationalError,
|
||||
connections,
|
||||
)
|
||||
from django.db import DatabaseError, InterfaceError, OperationalError, connections
|
||||
from psycopg.errors import AdminShutdown
|
||||
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
@@ -42,7 +37,6 @@ class HttpHandler(BaseHTTPRequestHandler):
|
||||
AdminShutdown,
|
||||
InterfaceError,
|
||||
DatabaseError,
|
||||
ConnectionError,
|
||||
OperationalError,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ ipnet.workspace = true
|
||||
json-subscriber.workspace = true
|
||||
notify.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
rcgen.workspace = true
|
||||
reqwest.workspace = true
|
||||
reqwest-middleware.workspace = true
|
||||
rustls.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Utilities for working with the authentik API client.
|
||||
|
||||
use ak_client::{apis::configuration::Configuration, models::Pagination};
|
||||
use ak_client::apis::configuration::Configuration;
|
||||
use eyre::{Result, eyre};
|
||||
use url::Url;
|
||||
|
||||
@@ -60,42 +60,6 @@ pub fn make_config() -> Result<Configuration> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch all pages from a paginated API endpoint, returning all results combined.
|
||||
///
|
||||
/// - `fetch`: a function that takes a page number and returns a future resolving to a paginated
|
||||
/// response.
|
||||
/// - `get_pagination`: a function that extracts the [`Pagination`] metadata from a response.
|
||||
/// - `get_results`: a function that extracts the result items from a response.
|
||||
pub async fn fetch_all<T, R, E, F, Fut>(
|
||||
fetch: F,
|
||||
get_pagination: impl Fn(&R) -> &Pagination,
|
||||
get_results: impl Fn(R) -> Vec<T>,
|
||||
) -> std::result::Result<Vec<T>, E>
|
||||
where
|
||||
F: Fn(i32) -> Fut,
|
||||
Fut: Future<Output = std::result::Result<R, E>>,
|
||||
{
|
||||
let mut page = 1;
|
||||
let mut results = Vec::with_capacity(0);
|
||||
|
||||
loop {
|
||||
let response = fetch(page).await?;
|
||||
let next = get_pagination(&response).next;
|
||||
if page == 1 {
|
||||
let count = get_pagination(&response).count as usize;
|
||||
results.reserve(count);
|
||||
}
|
||||
results.extend(get_results(response));
|
||||
if next > 0.0 {
|
||||
page += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
@@ -3,9 +3,8 @@ use std::{collections::HashMap, net::SocketAddr, num::NonZeroUsize};
|
||||
use ipnet::IpNet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 5] = [
|
||||
pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 4] = [
|
||||
"listen.http",
|
||||
"listen.https",
|
||||
"listen.metrics",
|
||||
"listen.trusted_proxy_cidrs",
|
||||
"log.http_headers",
|
||||
@@ -60,7 +59,6 @@ pub struct PostgreSQLConfig {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListenConfig {
|
||||
pub http: Vec<SocketAddr>,
|
||||
pub https: Vec<SocketAddr>,
|
||||
pub metrics: Vec<SocketAddr>,
|
||||
pub debug_tokio: SocketAddr,
|
||||
pub trusted_proxy_cidrs: Vec<IpNet>,
|
||||
|
||||
@@ -7,8 +7,6 @@ use tracing::trace;
|
||||
|
||||
use crate::config;
|
||||
|
||||
pub mod self_signed;
|
||||
|
||||
/// Dummy resolver for FIPS compliance check.
|
||||
#[derive(Debug)]
|
||||
struct EmptyCertResolver;
|
||||
@@ -1,52 +0,0 @@
|
||||
use eyre::Result;
|
||||
use rcgen::{
|
||||
Certificate, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, KeyPair,
|
||||
KeyUsagePurpose, PKCS_RSA_SHA256, SanType,
|
||||
};
|
||||
use rustls::{
|
||||
crypto::aws_lc_rs::sign::any_supported_type,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
sign::CertifiedKey,
|
||||
};
|
||||
use time::{Duration, OffsetDateTime};
|
||||
|
||||
pub fn generate() -> Result<(Certificate, KeyPair)> {
|
||||
let signing_key = KeyPair::generate_for(&PKCS_RSA_SHA256)?;
|
||||
|
||||
let mut params = CertificateParams::default();
|
||||
params.not_before = OffsetDateTime::now_utc();
|
||||
params.not_after = OffsetDateTime::now_utc() + Duration::days(365);
|
||||
params.distinguished_name = {
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::OrganizationName, "authentik");
|
||||
dn.push(DnType::CommonName, "authentik default certificate");
|
||||
dn
|
||||
};
|
||||
params.subject_alt_names = vec![SanType::DnsName("*".try_into()?)];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
|
||||
let cert = params.self_signed(&signing_key)?;
|
||||
|
||||
Ok((cert, signing_key))
|
||||
}
|
||||
|
||||
pub fn generate_certifiedkey() -> Result<CertifiedKey> {
|
||||
let (cert, keypair) = generate()?;
|
||||
|
||||
let cert_der = cert.der().to_vec();
|
||||
let key_der = keypair.serialize_der();
|
||||
|
||||
let private_key =
|
||||
PrivateKeyDer::try_from(key_der).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
|
||||
let signing_key =
|
||||
any_supported_type(&private_key).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
|
||||
|
||||
Ok(CertifiedKey::new(
|
||||
vec![CertificateDer::from(cert_der)],
|
||||
signing_key,
|
||||
))
|
||||
}
|
||||
@@ -30,12 +30,12 @@ pub fn install() -> Result<()> {
|
||||
}
|
||||
|
||||
if config.debug {
|
||||
// let console_layer = console_subscriber::ConsoleLayer::builder()
|
||||
// .server_addr(config.listen.debug_tokio)
|
||||
// .spawn();
|
||||
let console_layer = console_subscriber::ConsoleLayer::builder()
|
||||
.server_addr(config.listen.debug_tokio)
|
||||
.spawn();
|
||||
tracing_subscriber::registry()
|
||||
.with(ErrorLayer::default())
|
||||
// .with(console_layer)
|
||||
.with(console_layer)
|
||||
.with(
|
||||
fmt::layer()
|
||||
.compact()
|
||||
@@ -187,10 +187,13 @@ pub mod sentry {
|
||||
environment: config.environment,
|
||||
send_pii: config.send_pii,
|
||||
#[expect(
|
||||
clippy::as_conversions,
|
||||
clippy::cast_possible_truncation,
|
||||
reason = "This is fine, we'll never get big values here."
|
||||
)]
|
||||
#[expect(
|
||||
clippy::as_conversions,
|
||||
reason = "This is fine, we'll never get big values here."
|
||||
)]
|
||||
sample_rate: config.traces_sample_rate as f32,
|
||||
})
|
||||
}
|
||||
|
||||
2
packages/client-ts/package.json
generated
2
packages/client-ts/package.json
generated
@@ -8,8 +8,8 @@
|
||||
"url": "https://github.com/goauthentik/authentik.git"
|
||||
},
|
||||
"scripts": {
|
||||
"clean": "tsc -b --clean tsconfig.json tsconfig.esm.json",
|
||||
"build": "npm run clean && tsc -b tsconfig.json tsconfig.esm.json",
|
||||
"clean": "tsc -b --clean tsconfig.json tsconfig.esm.json",
|
||||
"prepare": "npm run build"
|
||||
},
|
||||
"main": "./dist/index.js",
|
||||
|
||||
@@ -32,16 +32,17 @@ class DjangoDramatiqPostgres(AppConfig):
|
||||
middleware=[],
|
||||
)
|
||||
|
||||
for middleware_class, middleware_kwargs in Conf().middlewares:
|
||||
middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)(
|
||||
**middleware_kwargs,
|
||||
)
|
||||
if isinstance(middleware, Results):
|
||||
middleware.backend = import_string(Conf().result_backend)(
|
||||
for middleware_class_path, middleware_kwargs in Conf().middlewares:
|
||||
middleware_class = import_string(middleware_class_path)
|
||||
if issubclass(middleware_class, Results):
|
||||
middleware_kwargs["backend"] = import_string(Conf().result_backend)(
|
||||
*Conf().result_backend_args,
|
||||
**Conf().result_backend_kwargs,
|
||||
)
|
||||
broker.add_middleware(middleware) # type: ignore[no-untyped-call]
|
||||
middleware: dramatiq.middleware.middleware.Middleware = middleware_class(
|
||||
**middleware_kwargs,
|
||||
)
|
||||
broker.add_middleware(middleware)
|
||||
|
||||
dramatiq.set_broker(broker)
|
||||
|
||||
|
||||
@@ -23,11 +23,9 @@ from django.utils.functional import cached_property
|
||||
from django.utils.module_loading import import_string
|
||||
from dramatiq.broker import Broker, Consumer, MessageProxy
|
||||
from dramatiq.common import compute_backoff, current_millis, dq_name, q_name, xq_name
|
||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
||||
from dramatiq.errors import BrokerConnectionError, QueueJoinTimeout
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware import (
|
||||
Middleware,
|
||||
)
|
||||
from dramatiq.middleware import Middleware
|
||||
from pglock.core import _cast_lock_id
|
||||
from psycopg import sql
|
||||
from psycopg.errors import AdminShutdown
|
||||
@@ -46,7 +44,6 @@ DATABASE_ERRORS = (
|
||||
AdminShutdown,
|
||||
InterfaceError,
|
||||
DatabaseError,
|
||||
ConnectionError,
|
||||
OperationalError,
|
||||
)
|
||||
|
||||
@@ -55,7 +52,7 @@ def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
|
||||
|
||||
|
||||
def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP047
|
||||
def raise_broker_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP047
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
@@ -66,13 +63,13 @@ def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]: # noqa: UP0
|
||||
connections.close_all()
|
||||
except DATABASE_ERRORS:
|
||||
pass
|
||||
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
|
||||
raise BrokerConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostgresBroker(Broker):
|
||||
queues: set[str] # type: ignore[assignment]
|
||||
queues: set[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -81,7 +78,7 @@ class PostgresBroker(Broker):
|
||||
db_alias: str = DEFAULT_DB_ALIAS,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, middleware=[], **kwargs) # type: ignore[no-untyped-call,misc]
|
||||
super().__init__(*args, middleware=[], **kwargs) # type: ignore[misc]
|
||||
self.logger = get_logger(__name__, type(self))
|
||||
|
||||
self.queues = set()
|
||||
@@ -122,10 +119,10 @@ class PostgresBroker(Broker):
|
||||
|
||||
def declare_queue(self, queue_name: str) -> None:
|
||||
if queue_name not in self.queues:
|
||||
self.emit_before("declare_queue", queue_name) # type: ignore[no-untyped-call]
|
||||
self.emit_before("declare_queue", queue_name)
|
||||
self.queues.add(queue_name)
|
||||
# Nothing more to do, all queues are in the same table
|
||||
self.emit_after("declare_queue", queue_name) # type: ignore[no-untyped-call]
|
||||
self.emit_after("declare_queue", queue_name)
|
||||
|
||||
def model_defaults(self, message: Message[Any]) -> dict[str, Any]:
|
||||
eta = None
|
||||
@@ -141,7 +138,7 @@ class PostgresBroker(Broker):
|
||||
}
|
||||
|
||||
@tenacity.retry(
|
||||
retry=tenacity.retry_if_exception_type(ConnectionError),
|
||||
retry=tenacity.retry_if_exception_type(BrokerConnectionError),
|
||||
reraise=True,
|
||||
wait=tenacity.wait_random_exponential(multiplier=1, max=5),
|
||||
stop=tenacity.stop_after_attempt(3),
|
||||
@@ -149,11 +146,11 @@ class PostgresBroker(Broker):
|
||||
cast(logging.Logger, logger), logging.INFO, exc_info=True
|
||||
),
|
||||
)
|
||||
@raise_connection_error
|
||||
@raise_broker_connection_error
|
||||
def enqueue(self, message: Message[Any], *, delay: int | None = None) -> Message[Any]:
|
||||
queue_name = q_name(message.queue_name) # type: ignore[no-untyped-call]
|
||||
queue_name = q_name(message.queue_name)
|
||||
if delay:
|
||||
message_eta = current_millis() + delay # type: ignore[no-untyped-call]
|
||||
message_eta = current_millis() + delay
|
||||
message.options["eta"] = message_eta
|
||||
|
||||
self.declare_queue(queue_name)
|
||||
@@ -163,7 +160,7 @@ class PostgresBroker(Broker):
|
||||
|
||||
message.options["model_defaults"] = self.model_defaults(message)
|
||||
message.options["model_create_defaults"] = {}
|
||||
self.emit_before("enqueue", message, delay) # type: ignore[no-untyped-call]
|
||||
self.emit_before("enqueue", message, delay)
|
||||
|
||||
with transaction.atomic(using=self.db_alias):
|
||||
query = {
|
||||
@@ -185,7 +182,7 @@ class PostgresBroker(Broker):
|
||||
message.options["task"] = task
|
||||
message.options["task_created"] = created
|
||||
|
||||
self.emit_after("enqueue", message, delay) # type: ignore[no-untyped-call]
|
||||
self.emit_after("enqueue", message, delay)
|
||||
return message
|
||||
|
||||
def get_declared_queues(self) -> set[str]:
|
||||
@@ -193,7 +190,7 @@ class PostgresBroker(Broker):
|
||||
|
||||
def flush(self, queue_name: str) -> None:
|
||||
self.query_set.filter(
|
||||
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name)) # type: ignore[no-untyped-call]
|
||||
queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
|
||||
).delete()
|
||||
|
||||
def flush_all(self) -> None:
|
||||
@@ -375,7 +372,7 @@ class _PostgresConsumer(Consumer):
|
||||
self.in_processing.add(str(message_id))
|
||||
return message
|
||||
|
||||
@raise_connection_error
|
||||
@raise_broker_connection_error
|
||||
def __next__(self) -> MessageProxy | None:
|
||||
# This method is called every second
|
||||
|
||||
@@ -395,7 +392,7 @@ class _PostgresConsumer(Consumer):
|
||||
if processing >= self.prefetch:
|
||||
# If we have too many messages already processing, wait and don't consume a message
|
||||
# straight away, other workers will be faster.
|
||||
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000) # type: ignore[no-untyped-call]
|
||||
self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
|
||||
self.logger.debug(
|
||||
"Too many messages in processing, Sleeping",
|
||||
processing=processing,
|
||||
@@ -420,7 +417,7 @@ class _PostgresConsumer(Consumer):
|
||||
break
|
||||
message = self._consume_one(str(message_id))
|
||||
if message is not None:
|
||||
return MessageProxy(message) # type: ignore[no-untyped-call]
|
||||
return MessageProxy(message)
|
||||
else:
|
||||
self.logger.debug("Message already consumed. Skipping.", message_id=message_id)
|
||||
continue
|
||||
@@ -444,7 +441,7 @@ class _PostgresConsumer(Consumer):
|
||||
self.to_unlock.add(str(message_id))
|
||||
return False
|
||||
|
||||
def _post_process_message(self, message: Message[Any], state: TaskState) -> None:
|
||||
def _post_process_message(self, message: MessageProxy, state: TaskState) -> None:
|
||||
self.logger.debug("Post-processing message", message=message.message_id, state=state)
|
||||
try:
|
||||
self.in_processing.remove(str(message.message_id))
|
||||
@@ -466,16 +463,16 @@ class _PostgresConsumer(Consumer):
|
||||
)
|
||||
message.options["task"] = task
|
||||
|
||||
@raise_connection_error
|
||||
def ack(self, message: Message[Any]) -> None:
|
||||
@raise_broker_connection_error
|
||||
def ack(self, message: MessageProxy) -> None:
|
||||
self._post_process_message(message, TaskState.DONE)
|
||||
|
||||
@raise_connection_error
|
||||
def nack(self, message: Message[Any]) -> None:
|
||||
@raise_broker_connection_error
|
||||
def nack(self, message: MessageProxy) -> None:
|
||||
self._post_process_message(message, TaskState.REJECTED)
|
||||
|
||||
@raise_connection_error
|
||||
def requeue(self, messages: Iterable[Message[Any]]) -> None:
|
||||
@raise_broker_connection_error
|
||||
def requeue(self, messages: Iterable[MessageProxy]) -> None:
|
||||
self.query_set.filter(
|
||||
message_id__in=[message.message_id for message in messages],
|
||||
).update(
|
||||
@@ -514,7 +511,7 @@ class _PostgresConsumer(Consumer):
|
||||
self.logger.info("Purged messages in all queues", count=count)
|
||||
self.task_purge_last_run = timezone.now()
|
||||
|
||||
@raise_connection_error
|
||||
@raise_broker_connection_error
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._purge_locks()
|
||||
|
||||
@@ -5,7 +5,7 @@ from signal import pause
|
||||
from django_dramatiq_postgres.conf import Conf
|
||||
|
||||
|
||||
def worker_metrics() -> None:
|
||||
def worker_metrics() -> int:
|
||||
import_module(Conf().autodiscovery["setup_module"])
|
||||
|
||||
from django_dramatiq_postgres.middleware import MetricsMiddleware
|
||||
@@ -15,3 +15,4 @@ def worker_metrics() -> None:
|
||||
int(os.getenv("dramatiq_prom_port", "9191")),
|
||||
)
|
||||
pause()
|
||||
return 0
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from django.db import DatabaseError, close_old_connections, connections
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.broker import Broker
|
||||
from dramatiq.broker import Broker, MessageProxy
|
||||
from dramatiq.common import current_millis
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware.middleware import Middleware
|
||||
@@ -79,7 +79,7 @@ class DbConnectionMiddleware(Middleware):
|
||||
|
||||
|
||||
class TaskStateBeforeMiddleware(Middleware):
|
||||
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: # type: ignore[override]
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -90,7 +90,7 @@ class TaskStateBeforeMiddleware(Middleware):
|
||||
|
||||
|
||||
class TaskStateAfterMiddleware(Middleware):
|
||||
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: PostgresBroker, message: MessageProxy) -> None: # type: ignore[override]
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -99,7 +99,7 @@ class TaskStateAfterMiddleware(Middleware):
|
||||
state=TaskState.RUNNING,
|
||||
)
|
||||
|
||||
def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
def after_skip_message(self, broker: PostgresBroker, message: MessageProxy) -> None: # type: ignore[override]
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -110,11 +110,11 @@ class TaskStateAfterMiddleware(Middleware):
|
||||
|
||||
def after_process_message(
|
||||
self,
|
||||
broker: PostgresBroker,
|
||||
message: Message[Any],
|
||||
broker: PostgresBroker, # type: ignore[override]
|
||||
message: MessageProxy,
|
||||
*,
|
||||
result: Any | None = None,
|
||||
exception: Exception | None = None,
|
||||
exception: BaseException | None = None,
|
||||
) -> None:
|
||||
self.after_skip_message(broker, message)
|
||||
|
||||
@@ -147,7 +147,7 @@ class CurrentTask(Middleware):
|
||||
raise CurrentTaskNotFound()
|
||||
return task[-1]
|
||||
|
||||
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: Broker, message: MessageProxy) -> None:
|
||||
tasks = self._TASKS.get()
|
||||
if tasks is None:
|
||||
tasks = []
|
||||
@@ -157,10 +157,10 @@ class CurrentTask(Middleware):
|
||||
def after_process_message(
|
||||
self,
|
||||
broker: Broker,
|
||||
message: Message[Any],
|
||||
message: MessageProxy,
|
||||
*,
|
||||
result: Any | None = None,
|
||||
exception: Exception | None = None,
|
||||
exception: BaseException | None = None,
|
||||
) -> None:
|
||||
tasks: list[TaskBase] | None = self._TASKS.get()
|
||||
if tasks is None or len(tasks) == 0:
|
||||
@@ -194,7 +194,7 @@ class CurrentTask(Middleware):
|
||||
pass
|
||||
self._TASKS.set(tasks[:-1])
|
||||
|
||||
def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
|
||||
def after_skip_message(self, broker: Broker, message: MessageProxy) -> None:
|
||||
self.after_process_message(broker, message)
|
||||
|
||||
|
||||
@@ -236,7 +236,7 @@ class MetricsMiddleware(Middleware):
|
||||
self.message_start_times: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def forks(self) -> list[Callable[[], None]]:
|
||||
def forks(self) -> list[Callable[[], int]]:
|
||||
from django_dramatiq_postgres.forks import worker_metrics
|
||||
|
||||
return [worker_metrics]
|
||||
@@ -310,41 +310,41 @@ class MetricsMiddleware(Middleware):
|
||||
# TODO: worker_id
|
||||
multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call]
|
||||
|
||||
def _make_labels(self, message: Message[Any]) -> list[str]:
|
||||
def _make_labels(self, message: MessageProxy | Message[Any]) -> list[str]:
|
||||
return [message.queue_name, message.actor_name]
|
||||
|
||||
def after_nack(self, broker: Broker, message: Message[Any]) -> None:
|
||||
def after_nack(self, broker: Broker, message: MessageProxy) -> None:
|
||||
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
|
||||
|
||||
def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
|
||||
if "retries" in message.options:
|
||||
self.total_retried_messages.labels(*self._make_labels(message)).inc()
|
||||
|
||||
def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
|
||||
def before_delay_message(self, broker: Broker, message: MessageProxy) -> None:
|
||||
self.delayed_messages.add(message.message_id)
|
||||
self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()
|
||||
|
||||
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: Broker, message: MessageProxy) -> None:
|
||||
labels = self._make_labels(message)
|
||||
if message.message_id in self.delayed_messages:
|
||||
self.delayed_messages.remove(message.message_id)
|
||||
self.in_progress_delayed_messages.labels(*labels).dec()
|
||||
|
||||
self.in_progress_messages.labels(*labels).inc()
|
||||
self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call]
|
||||
self.message_start_times[message.message_id] = current_millis()
|
||||
|
||||
def after_process_message(
|
||||
self,
|
||||
broker: Broker,
|
||||
message: Message[Any],
|
||||
message: MessageProxy,
|
||||
*,
|
||||
result: Any | None = None,
|
||||
exception: Exception | None = None,
|
||||
exception: BaseException | None = None,
|
||||
) -> None:
|
||||
labels = self._make_labels(message)
|
||||
|
||||
message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call]
|
||||
message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call]
|
||||
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
|
||||
message_duration = current_millis() - message_start_time
|
||||
self.messages_durations.labels(*labels).observe(message_duration)
|
||||
|
||||
self.in_progress_messages.labels(*labels).dec()
|
||||
|
||||
@@ -159,7 +159,7 @@ class ScheduleBase(models.Model):
|
||||
|
||||
def send(self, broker: Broker | None = None) -> Message[Any]:
|
||||
broker = broker or get_broker()
|
||||
actor: Actor[Any, Any] = broker.get_actor(self.actor_name) # type: ignore[no-untyped-call]
|
||||
actor: Actor[Any, Any] = broker.get_actor(self.actor_name)
|
||||
return actor.send_with_options(
|
||||
args=pickle.loads(self.args), # nosec
|
||||
kwargs=pickle.loads(self.kwargs), # nosec
|
||||
|
||||
@@ -36,7 +36,7 @@ dependencies = [
|
||||
"django >=4.2,<6.0",
|
||||
"django-pglock >=1.7,<2",
|
||||
"django-pgtrigger >=4,<5",
|
||||
"dramatiq >=1.17,<1.18",
|
||||
"dramatiq >=2,<3",
|
||||
"tenacity >=9,<10",
|
||||
"structlog >=25,<26",
|
||||
]
|
||||
|
||||
@@ -9678,18 +9678,14 @@ paths:
|
||||
security:
|
||||
- authentik: []
|
||||
responses:
|
||||
'204':
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/BlueprintImportResult'
|
||||
description: ''
|
||||
'400':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/BlueprintImportResult'
|
||||
description: ''
|
||||
$ref: '#/components/responses/ValidationErrorResponse'
|
||||
'403':
|
||||
$ref: '#/components/responses/GenericErrorResponse'
|
||||
/oauth2/access_tokens/:
|
||||
|
||||
10
src/main.rs
10
src/main.rs
@@ -8,8 +8,6 @@ use eyre::{Result, eyre};
|
||||
use tracing::{error, info, trace};
|
||||
|
||||
mod metrics;
|
||||
#[cfg(feature = "proxy")]
|
||||
mod outpost;
|
||||
#[cfg(feature = "core")]
|
||||
mod server;
|
||||
#[cfg(feature = "core")]
|
||||
@@ -31,8 +29,6 @@ enum Command {
|
||||
Server(server::Cli),
|
||||
#[cfg(feature = "core")]
|
||||
Worker(worker::Cli),
|
||||
#[cfg(feature = "proxy")]
|
||||
Proxy(outpost::proxy::Cli),
|
||||
}
|
||||
|
||||
#[derive(Debug, FromArgs, PartialEq)]
|
||||
@@ -57,8 +53,6 @@ fn main() -> Result<()> {
|
||||
Command::Server(_) => Mode::set(Mode::Server)?,
|
||||
#[cfg(feature = "core")]
|
||||
Command::Worker(_) => Mode::set(Mode::Worker)?,
|
||||
#[cfg(feature = "proxy")]
|
||||
Command::Proxy(_) => Mode::set(Mode::Proxy)?,
|
||||
}
|
||||
|
||||
trace!("installing error formatting");
|
||||
@@ -114,10 +108,6 @@ fn main() -> Result<()> {
|
||||
let workers = worker::start(args, &mut tasks)?;
|
||||
metrics.workers.store(Some(workers));
|
||||
}
|
||||
#[cfg(feature = "proxy")]
|
||||
Command::Proxy(args) => {
|
||||
outpost::start::<outpost::proxy::ProxyOutpost>(args, &mut tasks).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let errors = tasks.run().await;
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
use std::{fmt::Display, sync::Arc};
|
||||
|
||||
use ak_common::{Arbiter, Tasks, VERSION, api, arbiter, authentik_build_hash};
|
||||
use axum::http::{HeaderValue, header::AUTHORIZATION};
|
||||
use eyre::{Result, eyre};
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use nix::unistd::gethostname;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||
use time::UtcDateTime;
|
||||
use tokio::{
|
||||
signal::unix::SignalKind,
|
||||
time::{Duration, interval, sleep},
|
||||
};
|
||||
use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest as _};
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
use url::Url;
|
||||
|
||||
use crate::outpost::{Outpost, OutpostController};
|
||||
|
||||
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone, Copy, Eq)]
|
||||
#[repr(u8)]
|
||||
enum EventKind {
|
||||
/// Code used to acknowledge a previous message.
|
||||
Ack = 0,
|
||||
/// Code used to send a healthcheck keepalive.
|
||||
Hello = 1,
|
||||
/// Code received to trigger a config update.
|
||||
TriggerUpdate = 2,
|
||||
/// Code received to trigger some provider specific function.
|
||||
ProviderSpecific = 3,
|
||||
/// Code received to identify the end of a session.
|
||||
SessionEnd = 4,
|
||||
}
|
||||
|
||||
impl Display for EventKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Ack => write!(f, "Ack"),
|
||||
Self::Hello => write!(f, "Hello"),
|
||||
Self::TriggerUpdate => write!(f, "TriggerUpdate"),
|
||||
Self::ProviderSpecific => write!(f, "ProviderSpecific"),
|
||||
Self::SessionEnd => write!(f, "SessionEnd"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Event {
|
||||
instruction: EventKind,
|
||||
args: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct EventSessionEnd {
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
fn build_ws_url(mut url: Url, outpost_pk: &str, instance_uuid: &str, attempt: u32) -> Result<Url> {
|
||||
let ws_scheme = match url.scheme() {
|
||||
"https" => "wss",
|
||||
"http" => "ws",
|
||||
other => return Err(eyre!("Unsupported scheme for WebSocket URL: {other}")),
|
||||
};
|
||||
|
||||
url.set_scheme(ws_scheme)
|
||||
.map_err(|()| eyre!("Failed to set URL scheme to {ws_scheme}"))?;
|
||||
url.set_path(&format!("{}ws/outpost/{outpost_pk}/", url.path()));
|
||||
url.query_pairs_mut()
|
||||
.append_pair("instance_uuid", instance_uuid)
|
||||
.append_pair("attempt", &attempt.to_string());
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn hello_args(instance_uuid: &str) -> serde_json::Value {
|
||||
let raw_hostname = gethostname().unwrap_or_default();
|
||||
let hostname = raw_hostname.to_string_lossy();
|
||||
|
||||
serde_json::json!({
|
||||
"version": VERSION,
|
||||
"buildHash": authentik_build_hash(None),
|
||||
"uuid": instance_uuid,
|
||||
// TODO: rust version and AWS-LC versions
|
||||
"hostname": hostname,
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn handle_event<O: Outpost>(
|
||||
controller: Arc<OutpostController>,
|
||||
outpost: Arc<O>,
|
||||
event: Event,
|
||||
) -> Result<()> {
|
||||
match event.instruction {
|
||||
EventKind::Ack | EventKind::Hello => {}
|
||||
EventKind::TriggerUpdate => {
|
||||
info!("received update trigger, refreshing outpost");
|
||||
sleep(controller.reload_offset).await;
|
||||
controller.refresh().await?;
|
||||
debug!("outpost controller has been refreshed");
|
||||
outpost.refresh().await?;
|
||||
debug!("outpost has been refreshed");
|
||||
#[expect(
|
||||
clippy::as_conversions,
|
||||
clippy::cast_precision_loss,
|
||||
reason = "This is fine, we'll never get big values here."
|
||||
)]
|
||||
controller
|
||||
.m_last_update
|
||||
.set(UtcDateTime::now().unix_timestamp() as f64);
|
||||
}
|
||||
EventKind::SessionEnd => {
|
||||
let event: EventSessionEnd = serde_json::from_value(event.args)?;
|
||||
outpost.end_session(event).await?;
|
||||
}
|
||||
#[expect(
|
||||
clippy::unimplemented,
|
||||
reason = "this is only relevant for the RAC provider"
|
||||
)]
|
||||
EventKind::ProviderSpecific => unimplemented!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn watch_events_inner<O: Outpost>(
|
||||
arbiter: Arbiter,
|
||||
controller: Arc<OutpostController>,
|
||||
outpost: Arc<O>,
|
||||
attempt: u32,
|
||||
) -> Result<()> {
|
||||
let server_config = api::ServerConfig::new()?;
|
||||
let ws_url = build_ws_url(
|
||||
server_config.host,
|
||||
&controller.outpost.load().pk.to_string(),
|
||||
&controller.instance_uuid.to_string(),
|
||||
attempt,
|
||||
)?;
|
||||
|
||||
debug!(url = %ws_url, "connecting to websocket");
|
||||
let mut request = ws_url.into_client_request()?;
|
||||
let token = controller
|
||||
.api_config
|
||||
.bearer_access_token
|
||||
.as_deref()
|
||||
.unwrap_or("");
|
||||
request.headers_mut().insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Bearer {token}"))?,
|
||||
);
|
||||
|
||||
let (ws_stream, _response) = tokio_tungstenite::connect_async(request).await?;
|
||||
let (mut ws_write, mut ws_read) = ws_stream.split();
|
||||
|
||||
info!(
|
||||
outpost = %controller.outpost.load().pk,
|
||||
"connected to websocket"
|
||||
);
|
||||
controller.m_connection.set(1_u8);
|
||||
|
||||
let get_refresh_interval = || {
|
||||
let mut interval = controller.outpost.load().refresh_interval_s;
|
||||
// Ensure timer interval is not negative or 0.
|
||||
// If it is, we default to 5 minutes.
|
||||
if interval <= 0_i32 {
|
||||
interval = 60_i32 * 5_i32;
|
||||
}
|
||||
// Clamp interval to be at least 30 seconds.
|
||||
if interval < 30_i32 {
|
||||
interval = 30_i32;
|
||||
}
|
||||
// infallible because we bound it to be positive above
|
||||
Duration::from_secs(interval.try_into().expect("infallible"))
|
||||
};
|
||||
let mut refresh_interval = interval(get_refresh_interval());
|
||||
let mut heartbeat_interval = interval(Duration::from_secs(10));
|
||||
|
||||
let mut events_rx = arbiter.events_subscribe();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = refresh_interval.tick() => {
|
||||
info!("refreshing outpost on interval");
|
||||
if let Err(err) = handle_event(
|
||||
Arc::clone(&controller),
|
||||
Arc::clone(&outpost),
|
||||
Event {
|
||||
instruction: EventKind::TriggerUpdate,
|
||||
args: serde_json::Value::Null
|
||||
}
|
||||
).await {
|
||||
warn!(?err, "failed to refresh");
|
||||
}
|
||||
refresh_interval = interval(get_refresh_interval());
|
||||
// Since we re-create the interval, we need to make it tick instantly to avoid
|
||||
// ending up in a never-ending tick-loop.
|
||||
refresh_interval.tick().await;
|
||||
},
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let ping = Event {
|
||||
instruction: EventKind::Hello,
|
||||
args: hello_args(&controller.instance_uuid.to_string()),
|
||||
};
|
||||
ws_write.send(Message::text(serde_json::to_string(&ping)?)).await?;
|
||||
trace!("sent websocket hello (heartbeat)");
|
||||
},
|
||||
Ok(arbiter::Event::Signal(signal)) = events_rx.recv() => {
|
||||
if signal == SignalKind::user_defined1() {
|
||||
info!("refreshing outpost on signal");
|
||||
if let Err(err) = handle_event(
|
||||
Arc::clone(&controller),
|
||||
Arc::clone(&outpost),
|
||||
Event {
|
||||
instruction: EventKind::TriggerUpdate,
|
||||
args: serde_json::Value::Null
|
||||
}
|
||||
).await {
|
||||
warn!(?err, "failed to refresh");
|
||||
}
|
||||
}
|
||||
},
|
||||
msg = ws_read.next() => {
|
||||
let Some(msg) = msg else {
|
||||
break;
|
||||
};
|
||||
let msg = msg?;
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
let Ok(event): Result<Event, _> = serde_json::from_str(&text) else {
|
||||
warn!(data = text.as_str(), "failed to parse event");
|
||||
continue;
|
||||
};
|
||||
trace!(event = %event.instruction, "received websocket event");
|
||||
if let Err(err) = handle_event(
|
||||
Arc::clone(&controller),
|
||||
Arc::clone(&outpost),
|
||||
event,
|
||||
).await {
|
||||
warn!(?err, "failed to handle event");
|
||||
}
|
||||
},
|
||||
Message::Ping(data) => {
|
||||
ws_write.send(Message::Pong(data)).await?;
|
||||
},
|
||||
Message::Close(_) => {
|
||||
break;
|
||||
},
|
||||
_ => {},
|
||||
}
|
||||
},
|
||||
() = arbiter.shutdown() => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn watch_events<O: Outpost>(
|
||||
arbiter: Arbiter,
|
||||
controller: Arc<OutpostController>,
|
||||
outpost: Arc<O>,
|
||||
) -> Result<()> {
|
||||
const MAX_BACKOFF: Duration = Duration::from_mins(5);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let mut attempt: u32 = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
() = arbiter.shutdown() => break,
|
||||
res = watch_events_inner(
|
||||
arbiter.clone(),
|
||||
Arc::clone(&controller),
|
||||
Arc::clone(&outpost),
|
||||
attempt
|
||||
) => {
|
||||
controller.m_connection.set(0_u8);
|
||||
match res {
|
||||
Ok(()) => debug!("websocket disconnected cleanly"),
|
||||
Err(err) => warn!(?err, attempt, "websocket error"),
|
||||
}
|
||||
|
||||
info!(attempt, delay = backoff.as_secs(), "reconnecting websocket in {}s...", backoff.as_secs());
|
||||
|
||||
tokio::select! {
|
||||
() = arbiter.shutdown() => break,
|
||||
() = sleep(backoff) => {}
|
||||
}
|
||||
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
attempt += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("stopping event watcher");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn start<O: Outpost + 'static>(
|
||||
tasks: &mut Tasks,
|
||||
controller: Arc<OutpostController>,
|
||||
outpost: Arc<O>,
|
||||
) -> Result<()> {
|
||||
let arbiter = tasks.arbiter();
|
||||
tasks
|
||||
.build_task()
|
||||
.name(&format!("{}::watch_events", module_path!()))
|
||||
.spawn(watch_events(arbiter, controller, outpost))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use ak_client::{
|
||||
apis::{configuration::Configuration, outposts_api::outposts_instances_list},
|
||||
models::Outpost as OutpostModel,
|
||||
};
|
||||
use ak_common::{Tasks, VERSION, api, authentik_build_hash};
|
||||
use arc_swap::ArcSwap;
|
||||
use eyre::{Result, eyre};
|
||||
use tracing::{debug, info, instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) mod event;
|
||||
#[cfg(feature = "proxy")]
|
||||
pub(crate) mod proxy;
|
||||
|
||||
pub(crate) trait Outpost: Send + Sync + Sized {
|
||||
const OUTPOST_TYPE: &'static str;
|
||||
type Cli: Send + Sync;
|
||||
|
||||
async fn new(controller: Arc<OutpostController>) -> Result<Self>;
|
||||
|
||||
fn start(self: Arc<Self>, tasks: &mut Tasks) -> Result<()>;
|
||||
fn refresh(&self) -> impl Future<Output = Result<()>> + Send;
|
||||
|
||||
fn end_session(&self, event: event::EventSessionEnd)
|
||||
-> impl Future<Output = Result<()>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct OutpostController {
|
||||
api_config: Configuration,
|
||||
outpost: ArcSwap<OutpostModel>,
|
||||
instance_uuid: Uuid,
|
||||
reload_offset: Duration,
|
||||
m_info: metrics::Gauge,
|
||||
m_last_update: metrics::Gauge,
|
||||
m_connection: metrics::Gauge,
|
||||
}
|
||||
|
||||
impl OutpostController {
|
||||
#[instrument(skip_all)]
|
||||
async fn get_outpost(api_config: &Configuration) -> Result<OutpostModel> {
|
||||
let outposts = outposts_instances_list(
|
||||
api_config, None, None, None, None, None, None, None, None, None, None, None, None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let Some(outpost) = outposts.results.into_iter().next() else {
|
||||
return Err(eyre!(
|
||||
"No outposts found with given token, ensure the given token corresponds to an \
|
||||
authentik Outpost"
|
||||
));
|
||||
};
|
||||
debug!(name = outpost.name, "fetched outpost configuration");
|
||||
|
||||
Ok(outpost)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn new<O: Outpost>() -> Result<Self> {
|
||||
let api_config = api::make_config()?;
|
||||
let outpost = Self::get_outpost(&api_config).await?;
|
||||
let instance_uuid = Uuid::new_v4();
|
||||
|
||||
let m_labels = [
|
||||
("outpost_name", outpost.name.clone()),
|
||||
("outpost_type", O::OUTPOST_TYPE.to_owned()),
|
||||
("uuid", instance_uuid.to_string()),
|
||||
("version", VERSION.to_owned()),
|
||||
("build", authentik_build_hash(None)),
|
||||
];
|
||||
metrics::describe_gauge!("authentik_outpost_info", "Outpost info");
|
||||
let m_info = metrics::gauge!("authentik_outpost_info", &m_labels);
|
||||
metrics::describe_gauge!("authentik_outpost_last_update", "Time of last update");
|
||||
let m_last_update = metrics::gauge!("authentik_outpost_last_update", &m_labels);
|
||||
metrics::describe_gauge!("authentik_outpost_connection", "Connection status");
|
||||
let m_connection = metrics::gauge!("authentik_outpost_connection", &m_labels);
|
||||
|
||||
let reload_offset = Duration::from_secs(rand::random_range(0..10));
|
||||
let controller = Self {
|
||||
api_config,
|
||||
outpost: ArcSwap::from_pointee(outpost),
|
||||
instance_uuid,
|
||||
reload_offset,
|
||||
m_info,
|
||||
m_last_update,
|
||||
m_connection,
|
||||
};
|
||||
|
||||
info!(embedded = controller.is_embedded(), "outpost mode");
|
||||
debug!(?reload_offset, "HA Reload offset");
|
||||
|
||||
Ok(controller)
|
||||
}
|
||||
|
||||
fn is_embedded(&self) -> bool {
|
||||
self.outpost
|
||||
.load()
|
||||
.managed
|
||||
.as_ref()
|
||||
.and_then(|m| m.as_deref())
|
||||
.is_some_and(|m| m == "goauthentik.io/outposts/embedded")
|
||||
}
|
||||
|
||||
async fn refresh(&self) -> Result<()> {
|
||||
let outpost = Self::get_outpost(&self.api_config).await?;
|
||||
self.outpost.swap(Arc::new(outpost));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn start<O: Outpost + 'static>(_cli: O::Cli, tasks: &mut Tasks) -> Result<()> {
|
||||
let controller = Arc::new(OutpostController::new::<O>().await?);
|
||||
let outpost = Arc::new(O::new(Arc::clone(&controller)).await?);
|
||||
|
||||
event::start(tasks, Arc::clone(&controller), Arc::clone(&outpost))?;
|
||||
outpost.start(tasks)?;
|
||||
controller.m_info.set(1_u8);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use ak_client::{
|
||||
apis::crypto_api::{
|
||||
crypto_certificatekeypairs_view_certificate_retrieve,
|
||||
crypto_certificatekeypairs_view_private_key_retrieve,
|
||||
},
|
||||
models::ProxyOutpostConfig,
|
||||
};
|
||||
use axum::Router;
|
||||
use eyre::{Result, eyre};
|
||||
use rustls::{
|
||||
crypto::CryptoProvider,
|
||||
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
|
||||
sign::CertifiedKey,
|
||||
};
|
||||
use tracing::instrument;
|
||||
use url::Url;
|
||||
|
||||
use crate::outpost::proxy::ProxyOutpost;
|
||||
|
||||
const REDIRECT_PARAM: &str = "rd";
|
||||
const CALLBACK_SIGNATURE: &str = "X-authentik-auth-callback";
|
||||
const LOGOUT_SIGNATURE: &str = "X-authentik-logout";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct Application {
|
||||
pub(super) host: String,
|
||||
pub(super) provider: ProxyOutpostConfig,
|
||||
pub(super) router: Router,
|
||||
pub(super) cert: Option<Arc<CertifiedKey>>,
|
||||
}
|
||||
|
||||
impl Application {
|
||||
#[instrument(skip_all)]
|
||||
pub(super) async fn new(outpost: &ProxyOutpost, provider: ProxyOutpostConfig) -> Result<Self> {
|
||||
let external_url = Url::parse(&provider.external_host)?;
|
||||
if !external_url.has_authority() {
|
||||
return Err(eyre!("no host in external host"));
|
||||
}
|
||||
let external_host = external_url.authority();
|
||||
|
||||
// TODO: extract this to a certificate store to avoid re-fetching the certificate every time
|
||||
let cert = if let Some(Some(kp_uuid)) = provider.certificate {
|
||||
let cert = crypto_certificatekeypairs_view_certificate_retrieve(
|
||||
&outpost.controller.api_config,
|
||||
&kp_uuid.to_string(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
let key = crypto_certificatekeypairs_view_private_key_retrieve(
|
||||
&outpost.controller.api_config,
|
||||
&kp_uuid.to_string(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
let cert_chain = CertificateDer::pem_reader_iter(cert.data.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let key_der = PrivateKeyDer::from_pem_reader(key.data.as_bytes())?;
|
||||
let provider = CryptoProvider::get_default().expect("no rustls provider installed");
|
||||
Some(Arc::new(CertifiedKey::new(
|
||||
cert_chain,
|
||||
provider.key_provider.load_private_key(key_der)?,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let _redirect_url = {
|
||||
let mut redirect_url = external_url.join("outpost.goauthentik.io/callback")?;
|
||||
redirect_url.set_query(Some(&format!("{CALLBACK_SIGNATURE}=true")));
|
||||
redirect_url
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
host: external_host.to_owned(),
|
||||
provider,
|
||||
router: Router::new(),
|
||||
cert,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use ak_axum::{error::Result, extract::host::Host};
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{Method, StatusCode, header::CONTENT_TYPE},
|
||||
response::{IntoResponse as _, Response},
|
||||
};
|
||||
use metrics::histogram;
|
||||
use serde_json::json;
|
||||
use tokio::time::Instant;
|
||||
use tower::util::ServiceExt as _;
|
||||
use tracing::{Instrument as _, field, info_span, instrument, trace, warn};
|
||||
|
||||
use crate::outpost::proxy::ProxyOutpost;
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(super) async fn handle_ping(
|
||||
method: Method,
|
||||
Host(host): Host,
|
||||
State(outpost): State<Arc<ProxyOutpost>>,
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
histogram!(
|
||||
"authentik_outpost_proxy_request_duration_seconds",
|
||||
"outpost_name" => outpost.controller.outpost.load().name.clone(),
|
||||
"method" => method.to_string(),
|
||||
"host" => host,
|
||||
"type" => "ping",
|
||||
)
|
||||
.record(start.elapsed().as_secs_f64());
|
||||
StatusCode::NO_CONTENT.into_response()
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(super) async fn default(
|
||||
method: Method,
|
||||
Host(host): Host,
|
||||
State(outpost): State<Arc<ProxyOutpost>>,
|
||||
request: Request,
|
||||
) -> Result<Response> {
|
||||
let span = info_span!("proxy outpost request", user = field::Empty);
|
||||
let start = Instant::now();
|
||||
|
||||
let Some(app) = outpost.lookup_app(&host) else {
|
||||
trace!(headers = ?request.headers(), "tracing headers for no hostname match");
|
||||
warn!("no app for hostname");
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(
|
||||
json!({
|
||||
"message": "no app for hostname",
|
||||
"host": host,
|
||||
"detail": format!("check the outpost settings and make sure '{host}' is included."),
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
)
|
||||
.expect("infallible"));
|
||||
};
|
||||
|
||||
trace!("passing to application");
|
||||
let response = app.router.clone().oneshot(request).instrument(span).await?;
|
||||
|
||||
histogram!(
|
||||
"authentik_outpost_proxy_request_duration_seconds",
|
||||
"outpost_name" => outpost.controller.outpost.load().name.clone(),
|
||||
"method" => method.to_string(),
|
||||
"host" => host,
|
||||
"type" => "app",
|
||||
)
|
||||
.record(start.elapsed().as_secs_f64());
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use ak_axum::router::wrap_router;
|
||||
use ak_client::{apis::outposts_api::outposts_proxy_list, models::ProxyMode};
|
||||
use ak_common::{Tasks, api::fetch_all, config, tls};
|
||||
use arc_swap::ArcSwap;
|
||||
use argh::FromArgs;
|
||||
use axum::Router;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use eyre::Result;
|
||||
use rustls::{
|
||||
ServerConfig,
|
||||
server::{ClientHello, ResolvesServerCert},
|
||||
sign::CertifiedKey,
|
||||
};
|
||||
use tracing::{debug, error, info, instrument, warn};
|
||||
|
||||
use crate::outpost::{Outpost, OutpostController, proxy::application::Application};
|
||||
|
||||
mod application;
|
||||
mod handlers;
|
||||
|
||||
#[derive(Debug, Default, FromArgs, PartialEq, Eq)]
|
||||
/// Run the authentik proxy outpost.
|
||||
#[argh(subcommand, name = "proxy")]
|
||||
#[expect(
|
||||
clippy::empty_structs_with_brackets,
|
||||
reason = "argh doesn't support unit structs"
|
||||
)]
|
||||
pub(crate) struct Cli {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ProxyOutpost {
|
||||
controller: Arc<OutpostController>,
|
||||
apps: ArcSwap<HashMap<String, Arc<Application>>>,
|
||||
default_cert: Arc<CertifiedKey>,
|
||||
}
|
||||
|
||||
impl Outpost for ProxyOutpost {
|
||||
type Cli = Cli;
|
||||
|
||||
const OUTPOST_TYPE: &'static str = "proxy";
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn new(controller: Arc<OutpostController>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
controller,
|
||||
apps: ArcSwap::from_pointee(HashMap::with_capacity(0)),
|
||||
default_cert: Arc::new(tls::self_signed::generate_certifiedkey()?),
|
||||
})
|
||||
}
|
||||
|
||||
fn start(self: Arc<Self>, tasks: &mut Tasks) -> Result<()> {
|
||||
let router = build_router(Arc::clone(&self));
|
||||
|
||||
for addr in config::get().listen.http.iter().copied() {
|
||||
ak_axum::server::start_plain(tasks, "proxy-outpost", router.clone(), addr)?;
|
||||
}
|
||||
|
||||
for addr in config::get().listen.https.iter().copied() {
|
||||
let resolver = Arc::clone(&self);
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(resolver);
|
||||
let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
|
||||
ak_axum::server::start_tls(
|
||||
tasks,
|
||||
"proxy-outpost",
|
||||
router.clone(),
|
||||
addr,
|
||||
rustls_config,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn refresh(&self) -> Result<()> {
|
||||
debug!(
|
||||
outpost_pk = %self.controller.outpost.load().pk,
|
||||
"requesting providers for outpost"
|
||||
);
|
||||
|
||||
let providers = fetch_all(
|
||||
|page| {
|
||||
outposts_proxy_list(
|
||||
&self.controller.api_config,
|
||||
None,
|
||||
None,
|
||||
Some(page),
|
||||
Some(100_i32),
|
||||
None,
|
||||
)
|
||||
},
|
||||
|r| &r.pagination,
|
||||
|r| r.results,
|
||||
)
|
||||
.await
|
||||
.inspect_err(|err| error!(?err, "failed to fetch providers"))?;
|
||||
debug!(count = providers.len(), "fetched providers");
|
||||
|
||||
if providers.is_empty() && !self.controller.is_embedded() {
|
||||
warn!(
|
||||
"no providers assigned to this outpost, check outpost configuration in authentik"
|
||||
);
|
||||
}
|
||||
|
||||
for (i, provider) in providers.iter().enumerate() {
|
||||
debug!(
|
||||
index = i,
|
||||
name = provider.name,
|
||||
external_host = provider.external_host,
|
||||
assigned_to_app = provider.assigned_application_name,
|
||||
"provider details"
|
||||
);
|
||||
}
|
||||
|
||||
let mut apps = HashMap::with_capacity(providers.len());
|
||||
|
||||
for provider in providers {
|
||||
let name = provider.name.clone();
|
||||
let Ok(application) = Application::new(self, provider)
|
||||
.await
|
||||
.inspect_err(|err| warn!(?err, "failed to setup application, skipping provider"))
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
info!(name, host = application.host, "loaded application");
|
||||
|
||||
apps.insert(application.host.clone(), Arc::new(application));
|
||||
}
|
||||
|
||||
self.apps.store(Arc::new(apps));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn end_session(&self, _event: super::event::EventSessionEnd) -> Result<()> {
|
||||
// todo!()
|
||||
warn!(?_event, "removing session");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvesServerCert for ProxyOutpost {
|
||||
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
||||
if let Some(server_name) = client_hello.server_name()
|
||||
&& let Some(app) = self.apps.load().get(server_name)
|
||||
&& let Some(cert) = &app.cert
|
||||
{
|
||||
return Some(Arc::clone(cert));
|
||||
}
|
||||
Some(Arc::clone(&self.default_cert))
|
||||
}
|
||||
|
||||
fn only_raw_public_keys(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl ProxyOutpost {
|
||||
#[instrument(skip(self))]
|
||||
fn lookup_app(&self, host: &str) -> Option<Arc<Application>> {
|
||||
let apps = self.apps.load();
|
||||
|
||||
// If we only have a single app, host name switching doesn't matter.
|
||||
if apps.len() == 1
|
||||
&& let Some(app) = apps.values().next()
|
||||
{
|
||||
debug!(app = app.provider.name, "found a single app, using it");
|
||||
return Some(Arc::clone(app));
|
||||
}
|
||||
|
||||
if let Some(app) = apps.get(host) {
|
||||
debug!(app = app.provider.name, "found app based direct host match");
|
||||
return Some(Arc::clone(app));
|
||||
}
|
||||
|
||||
// For forward_auth_domain, we don't have a direct app to domain relationship.
|
||||
// Check through all apps, and check how much of their cookie domain matches the host.
|
||||
// Return the application that has the longest match.
|
||||
let mut longest_match = None;
|
||||
let mut longest_len = 0_usize;
|
||||
|
||||
for app in apps.values() {
|
||||
if app.provider.mode != Some(ProxyMode::ForwardDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(cookie_domain) = app.provider.cookie_domain.as_deref() {
|
||||
// Check if the cookie domain has a leading period for a wildcard.
|
||||
// This will decrease the weight of a wildcard domain, but a request to example.com
|
||||
// with the cookie domain set to example.com will still be routed correctly.
|
||||
let domain = cookie_domain.trim_start_matches('.');
|
||||
|
||||
if host.ends_with(domain) && domain.len() > longest_len {
|
||||
longest_len = domain.len();
|
||||
longest_match = Some(Arc::clone(app));
|
||||
}
|
||||
// For forward_auth_domain, we need to response on the external domain too.
|
||||
if app.provider.external_host == host {
|
||||
debug!(app = app.provider.name, "found app based on external_host");
|
||||
return Some(Arc::clone(app));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(app) = &longest_match {
|
||||
debug!(app = app.provider.name, "found app based on cookie domain");
|
||||
}
|
||||
|
||||
longest_match
|
||||
}
|
||||
}
|
||||
|
||||
fn build_router(outpost: Arc<ProxyOutpost>) -> Router {
|
||||
wrap_router(
|
||||
Router::new()
|
||||
.nest(
|
||||
"/outpost.goauthentik.io/ping",
|
||||
Router::new().fallback(handlers::handle_ping),
|
||||
)
|
||||
.fallback(handlers::default)
|
||||
.with_state(outpost),
|
||||
true,
|
||||
)
|
||||
}
|
||||
11
uv.lock
generated
11
uv.lock
generated
@@ -1143,7 +1143,7 @@ requires-dist = [
|
||||
{ name = "django", specifier = ">=4.2,<6.0" },
|
||||
{ name = "django-pglock", specifier = ">=1.7,<2" },
|
||||
{ name = "django-pgtrigger", specifier = ">=4,<5" },
|
||||
{ name = "dramatiq", specifier = ">=1.17,<1.18" },
|
||||
{ name = "dramatiq", specifier = ">=2,<3" },
|
||||
{ name = "structlog", specifier = ">=25,<26" },
|
||||
{ name = "tenacity", specifier = ">=9,<10" },
|
||||
]
|
||||
@@ -1381,14 +1381,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dramatiq"
|
||||
version = "1.17.1"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "prometheus-client" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c6/7a/6792ddc64a77d22bfd97261b751a7a76cf2f9d62edc59aafb679ac48b77d/dramatiq-1.17.1.tar.gz", hash = "sha256:2675d2f57e0d82db3a7d2a60f1f9c536365349db78c7f8d80a63e4c54697647a", size = 99071, upload-time = "2024-10-26T05:09:28.283Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/69/02b54e3fc4fe75721b322bc578054b4f03cec258ba614fa98a1a5bbe1efe/dramatiq-2.1.0.tar.gz", hash = "sha256:cf81550729de6cf64234b05bd63970645654aaf38967faa7a2b6e401384bb090", size = 105444, upload-time = "2026-03-03T11:22:10.067Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/36/925c7afd5db4f1a3f00676b9c3c58f31ff7ae29a347282d86c8d429280a5/dramatiq-1.17.1-py3-none-any.whl", hash = "sha256:951cdc334478dff8e5150bb02a6f7a947d215ee24b5aedaf738eff20e17913df", size = 120382, upload-time = "2024-10-26T05:09:26.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/91/422960c8c415fd31ca1519d71d6f7e4bcabb2cdcc5872f784467e9fe7237/dramatiq-2.1.0-py3-none-any.whl", hash = "sha256:3ef940c2815722d3679aed79ef96c805f02fd33d4361529b2de30f01511ca44d", size = 125543, upload-time = "2026-03-03T11:22:08.664Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
267
web/packages/lex/index.js
vendored
267
web/packages/lex/index.js
vendored
@@ -7,152 +7,186 @@
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {(this: Lexer, chr: string) => any} DefunctFunction
|
||||
* A token produced by a {@link LexerAction}. The lexer is agnostic to the
|
||||
* concrete token shape; consumers pick whatever representation suits them.
|
||||
*
|
||||
* @typedef {unknown} Token
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {(this: Lexer, ...args: RegExpExecArray) => string | string[] | undefined} RuleAction
|
||||
* A rule action. Invoked with the regex match (full match followed by capture
|
||||
* groups) bound to the owning {@link Lexer} so it can read or set `state`,
|
||||
* `index`, and `reject`.
|
||||
*
|
||||
* Return values:
|
||||
* - `null` (or `undefined` from an implicit return) — discard the match and continue scanning.
|
||||
* - a single token — yield it from {@link Lexer.lex}.
|
||||
* - an array of tokens — yield the first; queue the rest for subsequent calls.
|
||||
*
|
||||
* @callback LexerAction
|
||||
* @this {Lexer}
|
||||
* @param {...string} match
|
||||
* @returns {Token | Token[] | null | void}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} Rule
|
||||
* @property {RegExp} pattern
|
||||
* @property {boolean} global
|
||||
* @property {RuleAction} action
|
||||
* @property {number[]} start
|
||||
* @typedef {object} LexerRule
|
||||
* @property {RegExp} pattern Sticky-compiled pattern used to probe the input.
|
||||
* @property {boolean} global Whether the user-supplied pattern was global.
|
||||
* @property {LexerAction} action
|
||||
* @property {number[]} start States in which the rule is active. `[0]` is the default state; an empty array means "any state".
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} Match
|
||||
* @typedef {object} LexerMatch
|
||||
* @property {RegExpExecArray} result
|
||||
* @property {RuleAction} action
|
||||
* @property {LexerAction} action
|
||||
* @property {number} length
|
||||
* @property {boolean} global Whether the producing rule was declared with the `g` flag.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Handler invoked when no rule matches at the current position.
|
||||
*
|
||||
* @callback DefunctHandler
|
||||
* @this {Lexer}
|
||||
* @param {string} chr The unexpected character.
|
||||
* @returns {Token | Token[] | null | void}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @type {DefunctHandler}
|
||||
*/
|
||||
function defaultDefunct(chr) {
|
||||
throw new Error(`Unexpected character at index ${this.index - 1}: ${chr}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Lexer class for tokenizing input strings.
|
||||
*/
|
||||
export class Lexer {
|
||||
/**
|
||||
* @type {string[]}
|
||||
*/
|
||||
tokens = [];
|
||||
/**
|
||||
* @type {Rule[]}
|
||||
*/
|
||||
rules = [];
|
||||
/**
|
||||
* @type {number}
|
||||
*/
|
||||
remove = 0;
|
||||
/**
|
||||
* Current lexer state. Rules whose `start` array contains this value (or
|
||||
* is empty) are eligible to match. Odd-numbered states are also matched
|
||||
* by rules declared with `start: [0]`, mirroring flex's inclusive states.
|
||||
*
|
||||
* @type {number}
|
||||
*/
|
||||
state = 0;
|
||||
/**
|
||||
* @type {number}
|
||||
*/
|
||||
|
||||
/** @type {number} */
|
||||
index = 0;
|
||||
/**
|
||||
* @type {string}
|
||||
*/
|
||||
|
||||
/** @type {string} */
|
||||
input = "";
|
||||
|
||||
/**
|
||||
* @param {DefunctFunction} [defunct]
|
||||
* When set to `true` from inside an action, the current match is rolled
|
||||
* back and the next-best match is tried instead.
|
||||
*
|
||||
* @type {boolean}
|
||||
*/
|
||||
reject = false;
|
||||
|
||||
/** @type {LexerRule[]} */
|
||||
#rules = [];
|
||||
|
||||
/** @type {Token[]} */
|
||||
#tokens = [];
|
||||
|
||||
/** @type {number} */
|
||||
#remove = 0;
|
||||
|
||||
/** @type {DefunctHandler} */
|
||||
#defunct;
|
||||
|
||||
/**
|
||||
* @param {DefunctHandler} [defunct] Optional handler for unexpected characters.
|
||||
*/
|
||||
constructor(defunct) {
|
||||
defunct ||= function (chr) {
|
||||
throw new Error("Unexpected character at index " + (this.index - 1) + ": " + chr);
|
||||
};
|
||||
|
||||
this.defunct = defunct;
|
||||
this.#defunct = typeof defunct === "function" ? defunct : defaultDefunct;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a lexing rule.
|
||||
* Register a tokenization rule.
|
||||
*
|
||||
* @param {RegExp} pattern
|
||||
* @param {RuleAction} action
|
||||
* @param {number[]} [start]
|
||||
* @returns {Lexer}
|
||||
* @param {LexerAction} action
|
||||
* @param {number[]} [start] States in which the rule is active. Defaults to `[0]`.
|
||||
* @returns {this}
|
||||
*/
|
||||
addRule = (pattern, action, start) => {
|
||||
addRule(pattern, action, start) {
|
||||
const global = pattern.global;
|
||||
|
||||
if (!global || !pattern.sticky) {
|
||||
let flags = "gy";
|
||||
|
||||
if (pattern.multiline) flags += "m";
|
||||
if (pattern.ignoreCase) flags += "i";
|
||||
if (pattern.unicode) flags += "u";
|
||||
pattern = new RegExp(pattern.source, flags);
|
||||
}
|
||||
|
||||
if (!Array.isArray(start)) start = [0];
|
||||
|
||||
this.rules.push({
|
||||
pattern: pattern,
|
||||
global: global,
|
||||
action: action,
|
||||
start: start,
|
||||
this.#rules.push({
|
||||
pattern,
|
||||
global,
|
||||
action,
|
||||
start: Array.isArray(start) ? start : [0],
|
||||
});
|
||||
|
||||
return this;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the input string for lexing.
|
||||
* Reset the lexer and load a new input string.
|
||||
*
|
||||
* @param {string} input
|
||||
* @returns {Lexer}
|
||||
* @returns {this}
|
||||
*/
|
||||
setInput = (input) => {
|
||||
this.remove = 0;
|
||||
setInput(input) {
|
||||
this.#remove = 0;
|
||||
this.state = 0;
|
||||
this.index = 0;
|
||||
this.tokens.length = 0;
|
||||
this.#tokens.length = 0;
|
||||
this.input = input;
|
||||
return this;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Lex the next token from the input.
|
||||
* Produce the next token from the input, or `null` once exhausted.
|
||||
*
|
||||
* @returns {string | string[] | undefined}
|
||||
* @returns {Token | null}
|
||||
*/
|
||||
lex = () => {
|
||||
if (this.tokens.length) return this.tokens.shift();
|
||||
lex() {
|
||||
if (this.#tokens.length) return /** @type {Token} */ (this.#tokens.shift());
|
||||
|
||||
this.reject = true;
|
||||
|
||||
while (this.index <= this.input.length) {
|
||||
const matches = this.scan().splice(this.remove);
|
||||
const matches = this.#scan().splice(this.#remove);
|
||||
const index = this.index;
|
||||
|
||||
while (matches.length) {
|
||||
if (!this.reject) {
|
||||
break;
|
||||
}
|
||||
const match = matches.shift();
|
||||
if (!this.reject) break;
|
||||
|
||||
if (!match) break;
|
||||
|
||||
const result = match.result;
|
||||
const length = match.length;
|
||||
const match = /** @type {LexerMatch} */ (matches.shift());
|
||||
const { result, length } = match;
|
||||
this.index += length;
|
||||
this.reject = false;
|
||||
this.remove++;
|
||||
this.#remove++;
|
||||
|
||||
let token = match.action.apply(this, result);
|
||||
let token = match.action.apply(
|
||||
this,
|
||||
/** @type {string[]} */ (/** @type {unknown} */ (result)),
|
||||
);
|
||||
|
||||
if (this.reject) {
|
||||
this.index = result.index;
|
||||
} else if (Array.isArray(token)) {
|
||||
this.tokens = token.slice(1);
|
||||
token = token[0];
|
||||
} else {
|
||||
if (length) this.remove = 0;
|
||||
} else if (token !== null && token !== undefined) {
|
||||
if (Array.isArray(token)) {
|
||||
this.#tokens = token.slice(1);
|
||||
token = token[0];
|
||||
}
|
||||
if (length) this.#remove = 0;
|
||||
return token;
|
||||
}
|
||||
}
|
||||
@@ -161,79 +195,82 @@ export class Lexer {
|
||||
|
||||
if (index < input.length) {
|
||||
if (this.reject) {
|
||||
this.remove = 0;
|
||||
const token = this.defunct(input.charAt(this.index++));
|
||||
if (typeof token !== "undefined") {
|
||||
this.#remove = 0;
|
||||
const token = this.#defunct(input.charAt(this.index++));
|
||||
if (token !== null && token !== undefined) {
|
||||
if (Array.isArray(token)) {
|
||||
this.tokens = token.slice(1);
|
||||
this.#tokens = token.slice(1);
|
||||
return token[0];
|
||||
}
|
||||
|
||||
return token;
|
||||
}
|
||||
} else {
|
||||
if (this.index !== index) this.remove = 0;
|
||||
if (this.index !== index) this.#remove = 0;
|
||||
this.reject = true;
|
||||
}
|
||||
} else if (matches.length) this.reject = true;
|
||||
else break;
|
||||
} else if (matches.length) {
|
||||
this.reject = true;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan the input for matches.
|
||||
* Probe every state-eligible rule at the current position, returning the
|
||||
* matches sorted by length (longest first), with global rules pinned
|
||||
* after non-global ones to preserve flex's "longest non-global wins"
|
||||
* tie-breaking.
|
||||
*
|
||||
* @returns {Match[]}
|
||||
* @returns {LexerMatch[]}
|
||||
*/
|
||||
scan = () => {
|
||||
/**
|
||||
* @type {Match[]}
|
||||
*/
|
||||
#scan() {
|
||||
/** @type {LexerMatch[]} */
|
||||
const matches = [];
|
||||
let index = 0;
|
||||
|
||||
const state = this.state;
|
||||
const lastIndex = this.index;
|
||||
const input = this.input;
|
||||
|
||||
for (let i = 0, length = this.rules.length; i < length; i++) {
|
||||
const rule = this.rules[i];
|
||||
for (const rule of this.#rules) {
|
||||
const start = rule.start;
|
||||
const states = start.length;
|
||||
const eligible =
|
||||
!states || start.indexOf(state) >= 0 || (state % 2 && states === 1 && !start[0]);
|
||||
|
||||
if (!states || start.indexOf(state) >= 0 || (state % 2 && states === 1 && !start[0])) {
|
||||
const pattern = rule.pattern;
|
||||
pattern.lastIndex = lastIndex;
|
||||
const result = pattern.exec(input);
|
||||
if (!eligible) continue;
|
||||
|
||||
if (!result || result.index !== lastIndex) {
|
||||
continue;
|
||||
}
|
||||
const pattern = rule.pattern;
|
||||
pattern.lastIndex = lastIndex;
|
||||
const result = pattern.exec(input);
|
||||
|
||||
let j = matches.push({
|
||||
result: result,
|
||||
action: rule.action,
|
||||
length: result[0].length,
|
||||
});
|
||||
if (!result || result.index !== lastIndex) continue;
|
||||
|
||||
if (rule.global) {
|
||||
index = j;
|
||||
}
|
||||
let j = matches.push({
|
||||
result,
|
||||
action: rule.action,
|
||||
length: result[0].length,
|
||||
global: rule.global,
|
||||
});
|
||||
|
||||
while (--j > index) {
|
||||
const k = j - 1;
|
||||
while (--j > 0) {
|
||||
const k = j - 1;
|
||||
const cur = matches[j];
|
||||
const prev = matches[k];
|
||||
const longer = cur.length > prev.length;
|
||||
const tieFavorsCur = cur.length === prev.length && prev.global && !cur.global;
|
||||
|
||||
if (matches[j].length > matches[k].length) {
|
||||
const temple = matches[j];
|
||||
matches[j] = matches[k];
|
||||
matches[k] = temple;
|
||||
}
|
||||
}
|
||||
if (!longer && !tieFavorsCur) break;
|
||||
|
||||
matches[j] = prev;
|
||||
matches[k] = cur;
|
||||
}
|
||||
}
|
||||
|
||||
return matches;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export default Lexer;
|
||||
|
||||
@@ -23,6 +23,7 @@ import { certificateProvider, certificateSelector } from "#admin/brands/Certific
|
||||
|
||||
import {
|
||||
Application,
|
||||
AuthenticationEnum,
|
||||
Brand,
|
||||
CoreApi,
|
||||
CoreApplicationsListRequest,
|
||||
@@ -31,7 +32,6 @@ import {
|
||||
FlowsApi,
|
||||
UsageEnum,
|
||||
} from "@goauthentik/api";
|
||||
import { AuthenticationEnum } from "@goauthentik/api/dist/models/AuthenticationEnum.js";
|
||||
|
||||
import YAML from "yaml";
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import { DesignationToLabel, LayoutToLabel } from "#admin/flows/utils";
|
||||
import { policyEngineModes } from "#admin/policies/PolicyEngineModes";
|
||||
|
||||
import {
|
||||
AuthenticationEnum,
|
||||
DeniedActionEnum,
|
||||
Flow,
|
||||
FlowDesignationEnum,
|
||||
@@ -24,7 +25,6 @@ import {
|
||||
FlowsApi,
|
||||
UsageEnum,
|
||||
} from "@goauthentik/api";
|
||||
import { AuthenticationEnum } from "@goauthentik/api/dist/models/AuthenticationEnum.js";
|
||||
|
||||
import { msg } from "@lit/localize";
|
||||
import { html, TemplateResult } from "lit";
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ModelForm } from "#elements/forms/ModelForm";
|
||||
|
||||
import type { Stage } from "@goauthentik/api/dist/models/Stage";
|
||||
import type { Stage } from "@goauthentik/api";
|
||||
|
||||
import { msg } from "@lit/localize";
|
||||
|
||||
|
||||
@@ -19,10 +19,10 @@ import {
|
||||
CoreGroupsListRequest,
|
||||
Group,
|
||||
StagesApi,
|
||||
UserCreationModeEnum,
|
||||
UserTypeEnum,
|
||||
UserWriteStage,
|
||||
} from "@goauthentik/api";
|
||||
import { UserCreationModeEnum } from "@goauthentik/api/dist/models/UserCreationModeEnum.js";
|
||||
|
||||
import { msg } from "@lit/localize";
|
||||
import { html, TemplateResult } from "lit";
|
||||
|
||||
317
web/test/unit/lexer.test.ts
Normal file
317
web/test/unit/lexer.test.ts
Normal file
@@ -0,0 +1,317 @@
|
||||
/* eslint-disable func-names */
|
||||
import { Lexer } from "lex";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
const drain = (lexer: Lexer): unknown[] => {
|
||||
const out: unknown[] = [];
|
||||
let token: unknown;
|
||||
|
||||
while ((token = lexer.lex()) !== null) {
|
||||
out.push(token);
|
||||
}
|
||||
return out;
|
||||
};
|
||||
|
||||
describe("Lexer", () => {
|
||||
describe("addRule", () => {
|
||||
it("returns the lexer for chaining", () => {
|
||||
const lexer = new Lexer();
|
||||
expect(lexer.addRule(/a/, () => "a")).toBe(lexer);
|
||||
});
|
||||
|
||||
it("preserves multiline, ignoreCase, and unicode flags when re-compiling", () => {
|
||||
const lexer = new Lexer(() => null);
|
||||
const seen: string[] = [];
|
||||
|
||||
lexer.addRule(/^a/im, (m) => {
|
||||
seen.push(m);
|
||||
});
|
||||
lexer.setInput("A\nA");
|
||||
|
||||
drain(lexer);
|
||||
expect(seen).toEqual(["A", "A"]);
|
||||
});
|
||||
|
||||
it("matches unicode patterns", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/\p{Letter}+/u, (m) => m);
|
||||
lexer.setInput("café");
|
||||
|
||||
expect(lexer.lex()).toBe("café");
|
||||
});
|
||||
});
|
||||
|
||||
describe("setInput", () => {
|
||||
it("resets state, index, and pending tokens", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/./, (c) => c);
|
||||
|
||||
lexer.setInput("ab");
|
||||
expect(lexer.lex()).toBe("a");
|
||||
lexer.state = 7;
|
||||
|
||||
lexer.setInput("xy");
|
||||
expect(lexer.state).toBe(0);
|
||||
expect(lexer.index).toBe(0);
|
||||
expect(lexer.lex()).toBe("x");
|
||||
expect(lexer.lex()).toBe("y");
|
||||
});
|
||||
|
||||
it("returns the lexer for chaining", () => {
|
||||
const lexer = new Lexer();
|
||||
expect(lexer.setInput("")).toBe(lexer);
|
||||
});
|
||||
});
|
||||
|
||||
describe("tokenization", () => {
|
||||
it("tokenizes a simple expression", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer
|
||||
.addRule(/\s+/, () => null)
|
||||
.addRule(/[a-zA-Z]+/, (m) => ({ type: "ident", value: m }))
|
||||
.addRule(/\d+/, (m) => ({ type: "num", value: Number(m) }))
|
||||
.addRule(/[+\-*/]/, (m) => ({ type: "op", value: m }));
|
||||
|
||||
lexer.setInput("foo + 12 * bar");
|
||||
expect(drain(lexer)).toEqual([
|
||||
{ type: "ident", value: "foo" },
|
||||
{ type: "op", value: "+" },
|
||||
{ type: "num", value: 12 },
|
||||
{ type: "op", value: "*" },
|
||||
{ type: "ident", value: "bar" },
|
||||
]);
|
||||
});
|
||||
|
||||
it("skips matches whose action returns null", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/\s+/, () => null).addRule(/\S+/, (m) => m);
|
||||
|
||||
lexer.setInput(" foo bar ");
|
||||
expect(drain(lexer)).toEqual(["foo", "bar"]);
|
||||
});
|
||||
|
||||
it("returns null once the input is exhausted", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/./, (c) => c);
|
||||
lexer.setInput("a");
|
||||
|
||||
expect(lexer.lex()).toBe("a");
|
||||
expect(lexer.lex()).toBeNull();
|
||||
expect(lexer.lex()).toBeNull();
|
||||
});
|
||||
|
||||
it("passes capture groups to the action", () => {
|
||||
const lexer = new Lexer();
|
||||
const calls: string[][] = [];
|
||||
|
||||
lexer.addRule(/(\w+)=(\w+)/, (...args) => {
|
||||
calls.push(args);
|
||||
return args[0];
|
||||
});
|
||||
|
||||
lexer.setInput("foo=bar");
|
||||
lexer.lex();
|
||||
expect(calls).toEqual([["foo=bar", "foo", "bar"]]);
|
||||
});
|
||||
|
||||
it("binds `this` to the lexer inside the action", () => {
|
||||
const lexer = new Lexer();
|
||||
let captured: Lexer | undefined;
|
||||
|
||||
lexer.addRule(/a/, function () {
|
||||
// eslint-disable-next-line consistent-this, @typescript-eslint/no-this-alias
|
||||
captured = this;
|
||||
return "a";
|
||||
});
|
||||
|
||||
lexer.setInput("a");
|
||||
lexer.lex();
|
||||
expect(captured).toBe(lexer);
|
||||
});
|
||||
});
|
||||
|
||||
describe("longest-match tie-breaking", () => {
|
||||
it("prefers the longer non-global match", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/if/, () => "KW_IF").addRule(/iffy/, () => "IDENT_IFFY");
|
||||
|
||||
lexer.setInput("iffy");
|
||||
expect(lexer.lex()).toBe("IDENT_IFFY");
|
||||
});
|
||||
|
||||
it("treats global rules as fallbacks behind non-global rules of the same length", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/[a-z]+/g, (m) => `g:${m}`).addRule(/foo/, (m) => `s:${m}`);
|
||||
|
||||
lexer.setInput("foo");
|
||||
expect(lexer.lex()).toBe("s:foo");
|
||||
});
|
||||
});
|
||||
|
||||
describe("multi-token return", () => {
|
||||
it("yields the first token immediately and queues the rest", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/a/, () => ["A1", "A2", "A3"]);
|
||||
|
||||
lexer.setInput("a");
|
||||
expect(lexer.lex()).toBe("A1");
|
||||
expect(lexer.lex()).toBe("A2");
|
||||
expect(lexer.lex()).toBe("A3");
|
||||
expect(lexer.lex()).toBeNull();
|
||||
});
|
||||
|
||||
it("drains the queue before scanning further input", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/a/, () => ["A1", "A2"]).addRule(/b/, () => "B");
|
||||
|
||||
lexer.setInput("ab");
|
||||
expect(drain(lexer)).toEqual(["A1", "A2", "B"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("reject", () => {
|
||||
it("falls through to the next-best match when an action sets reject", () => {
|
||||
const lexer = new Lexer();
|
||||
const order: string[] = [];
|
||||
|
||||
lexer
|
||||
.addRule(/foo/, function () {
|
||||
order.push("first");
|
||||
this.reject = true;
|
||||
})
|
||||
.addRule(/foo/, () => {
|
||||
order.push("second");
|
||||
return "FOO";
|
||||
});
|
||||
|
||||
lexer.setInput("foo");
|
||||
expect(lexer.lex()).toBe("FOO");
|
||||
expect(order).toEqual(["first", "second"]);
|
||||
});
|
||||
|
||||
it("rolls back the lexer index when an action rejects", () => {
|
||||
const lexer = new Lexer();
|
||||
|
||||
lexer
|
||||
.addRule(/abc/, function () {
|
||||
this.reject = true;
|
||||
})
|
||||
.addRule(/a/, (m) => m);
|
||||
|
||||
lexer.setInput("abc");
|
||||
expect(lexer.lex()).toBe("a");
|
||||
expect(lexer.index).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("defunct handling", () => {
|
||||
it("throws by default on unexpected characters", () => {
|
||||
const lexer = new Lexer();
|
||||
lexer.addRule(/a/, (m) => m);
|
||||
|
||||
lexer.setInput("a@");
|
||||
expect(lexer.lex()).toBe("a");
|
||||
expect(() => lexer.lex()).toThrow(/Unexpected character at index 1: @/);
|
||||
});
|
||||
|
||||
it("invokes a custom defunct handler with the offending character", () => {
|
||||
const defunct = vi.fn((chr: string) => `?${chr}`);
|
||||
const lexer = new Lexer(defunct);
|
||||
lexer.addRule(/a/, (m) => m);
|
||||
|
||||
lexer.setInput("a@b");
|
||||
expect(drain(lexer)).toEqual(["a", "?@", "?b"]);
|
||||
expect(defunct).toHaveBeenCalledTimes(2);
|
||||
expect(defunct.mock.calls[0]?.[0]).toBe("@");
|
||||
});
|
||||
|
||||
it("ignores defunct return values that are null", () => {
|
||||
const lexer = new Lexer((_chr) => null);
|
||||
lexer.addRule(/a/, (m) => m);
|
||||
|
||||
lexer.setInput("@@a");
|
||||
expect(lexer.lex()).toBe("a");
|
||||
expect(lexer.lex()).toBeNull();
|
||||
});
|
||||
|
||||
it("supports array returns from the defunct handler", () => {
|
||||
const lexer = new Lexer((chr) => [`bad:${chr}`, "extra"]);
|
||||
lexer.addRule(/a/, (m) => m);
|
||||
|
||||
lexer.setInput("@");
|
||||
expect(lexer.lex()).toBe("bad:@");
|
||||
expect(lexer.lex()).toBe("extra");
|
||||
});
|
||||
|
||||
it("falls back to the default handler when given a non-function", () => {
|
||||
// @ts-expect-error — exercising the runtime guard
|
||||
const lexer = new Lexer("not a function");
|
||||
lexer.setInput("@");
|
||||
expect(() => lexer.lex()).toThrow(/Unexpected character/);
|
||||
});
|
||||
});
|
||||
|
||||
describe("states", () => {
|
||||
it("only fires rules whose start array includes the current state", () => {
|
||||
const lexer = new Lexer();
|
||||
|
||||
lexer
|
||||
.addRule(/"/, function () {
|
||||
this.state = 2;
|
||||
})
|
||||
.addRule(
|
||||
/"/,
|
||||
function () {
|
||||
this.state = 0;
|
||||
},
|
||||
[2],
|
||||
)
|
||||
.addRule(/[^"]+/, (m) => `STR:${m}`, [2])
|
||||
.addRule(/[a-z]+/, (m) => `ID:${m}`);
|
||||
|
||||
lexer.setInput('foo"hello"bar');
|
||||
expect(drain(lexer)).toEqual(["ID:foo", "STR:hello", "ID:bar"]);
|
||||
});
|
||||
|
||||
it("treats an empty start array as 'active in any state'", () => {
|
||||
const lexer = new Lexer();
|
||||
|
||||
lexer
|
||||
.addRule(/!/, function () {
|
||||
this.state = 5;
|
||||
return "BANG";
|
||||
})
|
||||
.addRule(/./, (m) => m, []);
|
||||
|
||||
lexer.setInput("a!b");
|
||||
expect(drain(lexer)).toEqual(["a", "BANG", "b"]);
|
||||
});
|
||||
|
||||
it("matches inclusive `[0]` rules from odd-numbered states", () => {
|
||||
const lexer = new Lexer();
|
||||
|
||||
lexer
|
||||
.addRule(/#/, function () {
|
||||
this.state = 1;
|
||||
})
|
||||
.addRule(/[a-z]+/, (m) => m);
|
||||
|
||||
lexer.setInput("ab#cd");
|
||||
expect(drain(lexer)).toEqual(["ab", "cd"]);
|
||||
});
|
||||
|
||||
it("does not match `[0]` rules from even non-zero states", () => {
|
||||
const lexer = new Lexer();
|
||||
|
||||
lexer
|
||||
.addRule(/#/, function () {
|
||||
this.state = 2;
|
||||
})
|
||||
.addRule(/[a-z]+/, (m) => m);
|
||||
|
||||
lexer.setInput("ab#cd");
|
||||
expect(lexer.lex()).toBe("ab");
|
||||
expect(() => lexer.lex()).toThrow(/Unexpected character/);
|
||||
});
|
||||
});
|
||||
});
|
||||
35
web/test/unit/tsconfig.json
Normal file
35
web/test/unit/tsconfig.json
Normal file
@@ -0,0 +1,35 @@
|
||||
// @file TSConfig used by the web package during development.
|
||||
{
|
||||
"extends": "@goauthentik/tsconfig",
|
||||
"compilerOptions": {
|
||||
"types": ["node"],
|
||||
"checkJs": true,
|
||||
"allowJs": true,
|
||||
"composite": true,
|
||||
"resolveJsonModule": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"emitDeclarationOnly": true,
|
||||
"target": "esnext",
|
||||
"module": "preserve",
|
||||
"moduleResolution": "bundler",
|
||||
"lib": ["DOM", "DOM.Iterable", "ESNext"],
|
||||
|
||||
"noUncheckedIndexedAccess": true
|
||||
},
|
||||
"include": ["./**/*", "../**/*"],
|
||||
"exclude": [
|
||||
// ---
|
||||
"**/out/**/*",
|
||||
"**/dist/**/*",
|
||||
"storybook-static",
|
||||
// TODO: @lit/localize-tools v0.8.0 has a nullish coalescing typing error.
|
||||
// Remove when we upgrade past that.
|
||||
"scripts/pseudolocalize.mjs",
|
||||
"scripts/build-locales.mjs"
|
||||
],
|
||||
"references": [
|
||||
{
|
||||
"path": "../.."
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -41,9 +41,12 @@ export default defineConfig({
|
||||
projects: [
|
||||
{
|
||||
test: {
|
||||
include: ["./unit/**/*.{test,spec}.ts", "**/*.unit.{test,spec}.ts"],
|
||||
name: "unit",
|
||||
include: ["./test/unit/**/*.{test,spec}.ts", "**/*.unit.{test,spec}.ts"],
|
||||
name: "Unit Tests",
|
||||
environment: "node",
|
||||
typecheck: {
|
||||
tsconfig: "./tsconfig.unit.json",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -51,7 +54,7 @@ export default defineConfig({
|
||||
setupFiles: ["./test/lit/setup.js"],
|
||||
|
||||
include: ["./browser/**/*.{test,spec}.ts", "**/*.browser.{test,spec}.ts"],
|
||||
name: "browser",
|
||||
name: "Browser Tests",
|
||||
browser: {
|
||||
enabled: true,
|
||||
provider: playwright(),
|
||||
|
||||
Reference in New Issue
Block a user